feat: 重构MCP功能和AI服务提供者架构
This commit is contained in:
+118
-164
@@ -45,7 +45,7 @@ from app.services.memory_service import memory_service
|
||||
from app.services.chapter_regenerator import ChapterRegenerator
|
||||
from app.logger import get_logger
|
||||
from app.api.settings import get_user_ai_service
|
||||
from app.utils.sse_response import create_sse_response
|
||||
from app.utils.sse_response import SSEResponse, create_sse_response
|
||||
|
||||
router = APIRouter(prefix="/chapters", tags=["章节管理"])
|
||||
logger = get_logger(__name__)
|
||||
@@ -1172,7 +1172,6 @@ async def generate_chapter_content_stream(
|
||||
"""
|
||||
style_id = generate_request.style_id
|
||||
target_word_count = generate_request.target_word_count or 3000
|
||||
enable_mcp = generate_request.enable_mcp if hasattr(generate_request, 'enable_mcp') else True
|
||||
custom_model = generate_request.model if hasattr(generate_request, 'model') else None
|
||||
temp_narrative_perspective = generate_request.narrative_perspective if hasattr(generate_request, 'narrative_perspective') else None
|
||||
# 预先验证章节存在性(使用临时会话)
|
||||
@@ -1211,25 +1210,36 @@ async def generate_chapter_content_stream(
|
||||
# 获取当前用户ID(在生成器外部就需要)
|
||||
current_user_id = getattr(request.state, "user_id", "system")
|
||||
|
||||
# 初始化标准进度追踪器
|
||||
from app.utils.sse_response import WizardProgressTracker
|
||||
tracker = WizardProgressTracker("章节")
|
||||
|
||||
try:
|
||||
yield await tracker.start()
|
||||
|
||||
# 创建新的数据库会话
|
||||
async for db_session in get_db(request):
|
||||
# === 加载阶段 ===
|
||||
yield await tracker.loading("加载章节信息...", 0.2)
|
||||
|
||||
# 重新获取章节信息
|
||||
chapter_result = await db_session.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
current_chapter = chapter_result.scalar_one_or_none()
|
||||
if not current_chapter:
|
||||
yield f"data: {json.dumps({'type': 'error', 'error': '章节不存在'}, ensure_ascii=False)}\n\n"
|
||||
yield await tracker.error("章节不存在", 404)
|
||||
return
|
||||
|
||||
yield await tracker.loading("加载项目信息...", 0.4)
|
||||
|
||||
# 获取项目信息
|
||||
project_result = await db_session.execute(
|
||||
select(Project).where(Project.id == current_chapter.project_id)
|
||||
)
|
||||
project = project_result.scalar_one_or_none()
|
||||
if not project:
|
||||
yield f"data: {json.dumps({'type': 'error', 'error': '项目不存在'}, ensure_ascii=False)}\n\n"
|
||||
yield await tracker.error("项目不存在", 404)
|
||||
return
|
||||
|
||||
# 获取项目的大纲模式
|
||||
@@ -1333,80 +1343,7 @@ async def generate_chapter_content_stream(
|
||||
logger.info(f" - 相关记忆: {chapter_context.context_stats.get('memory_count', 0)} 条")
|
||||
logger.info(f" - 总上下文长度: {chapter_context.context_stats.get('total_length', 0)} 字符")
|
||||
|
||||
# 发送开始事件
|
||||
yield f"data: {json.dumps({'type': 'start', 'message': '开始AI创作...'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 发送初始进度0%
|
||||
yield f"data: {json.dumps({'type': 'progress', 'progress': 0, 'message': '准备生成...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 🔧 MCP工具增强:收集章节参考资料(优化版)
|
||||
mcp_reference_materials = ""
|
||||
if enable_mcp and current_user_id:
|
||||
try:
|
||||
# 1️⃣ 静默检查工具可用性
|
||||
from app.services.mcp_tool_service import mcp_tool_service
|
||||
available_tools = await mcp_tool_service.get_user_enabled_tools(
|
||||
user_id=current_user_id,
|
||||
db_session=db_session
|
||||
)
|
||||
|
||||
# 2️⃣ 只在有工具时才显示消息和调用
|
||||
if available_tools:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': '🔍 使用MCP工具收集参考资料...', 'progress': 28}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 构建资料收集提示词
|
||||
planning_prompt = f"""你正在为小说《{project.title}》创作第{current_chapter.chapter_number}章《{current_chapter.title}》。
|
||||
|
||||
【章节大纲】
|
||||
{outline.content if outline else current_chapter.summary or '暂无大纲'}
|
||||
|
||||
【小说信息】
|
||||
- 题材:{project.genre or '未设定'}
|
||||
- 主题:{project.theme or '未设定'}
|
||||
- 时代背景:{project.world_time_period or '未设定'}
|
||||
- 地理位置:{project.world_location or '未设定'}
|
||||
|
||||
【任务】
|
||||
请使用可用工具搜索相关背景资料,帮助创作更真实、更有深度的章节内容。
|
||||
你可以查询:
|
||||
1. 该章节涉及的历史事件或时代背景
|
||||
2. 地理环境和场景描写参考
|
||||
3. 相关领域的专业知识(如武术、科技、魔法等)
|
||||
4. 文化习俗和生活细节
|
||||
|
||||
请根据章节内容,有针对性地查询1-2个最关键的问题。"""
|
||||
|
||||
# 调用MCP增强的AI(非流式,限制1轮避免超时)
|
||||
planning_result = await user_ai_service.generate_text_with_mcp(
|
||||
prompt=planning_prompt,
|
||||
user_id=current_user_id,
|
||||
db_session=db_session,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=2, # ✅ 减少为1轮,避免超时
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
)
|
||||
|
||||
# 3️⃣ 提取参考资料并显示结果
|
||||
if planning_result.get("tool_calls_made", 0) > 0:
|
||||
tool_count = planning_result["tool_calls_made"]
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': f'✅ MCP工具调用成功({tool_count}次)', 'progress': 32}, ensure_ascii=False)}\n\n"
|
||||
mcp_reference_materials = planning_result.get("content", "")
|
||||
logger.info(f"📚 MCP工具收集参考资料:{len(mcp_reference_materials)} 字符")
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': 'ℹ️ MCP未使用工具,继续', 'progress': 32}, ensure_ascii=False)}\n\n"
|
||||
else:
|
||||
logger.debug(f"用户 {current_user_id} 未启用MCP工具,跳过MCP增强")
|
||||
# 未启用MCP时也发送进度,保持连贯性
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': '准备生成内容...', 'progress': 10}, ensure_ascii=False)}\n\n"
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式: {str(e)}")
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': '⚠️ MCP工具暂时不可用,使用基础模式', 'progress': 10}, ensure_ascii=False)}\n\n"
|
||||
else:
|
||||
# 如果未启用MCP,也发送基础进度
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': '开始构建创作上下文...', 'progress': 10}, ensure_ascii=False)}\n\n"
|
||||
yield await tracker.loading("上下文构建完成", 0.8)
|
||||
|
||||
# 🎭 确定使用的叙事人称(临时指定 > 项目默认 > 系统默认)
|
||||
chapter_perspective = (
|
||||
@@ -1496,26 +1433,17 @@ async def generate_chapter_content_stream(
|
||||
characters_info=characters_info or '暂无角色信息'
|
||||
)
|
||||
|
||||
# 添加 MCP 参考资料(如果有)
|
||||
if mcp_reference_materials:
|
||||
mcp_section = f"\n\n<mcp_reference>\n{mcp_reference_materials}\n</mcp_reference>"
|
||||
base_prompt = base_prompt.replace("</task>", f"{mcp_section}\n</task>")
|
||||
logger.info(f"📖 已整合MCP参考资料({len(mcp_reference_materials)}字符)")
|
||||
|
||||
# 应用写作风格
|
||||
if style_content:
|
||||
prompt = WritingStyleManager.apply_style_to_prompt(base_prompt, style_content)
|
||||
else:
|
||||
prompt = base_prompt
|
||||
|
||||
if mcp_reference_materials:
|
||||
logger.info(f"📖 已整合MCP参考资料({len(mcp_reference_materials)}字符)到章节生成提示词")
|
||||
# === 准备阶段 ===
|
||||
yield await tracker.preparing("准备AI提示词...")
|
||||
|
||||
logger.info(f"开始AI流式创作章节 {chapter_id}")
|
||||
|
||||
# 发送开始生成的进度
|
||||
yield f"data: {json.dumps({'type': 'progress', 'progress': 10, 'message': '开始AI创作...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 🎨 方案一:将写作风格注入到系统提示词(最高优先级)
|
||||
system_prompt_with_style = None
|
||||
if style_content:
|
||||
@@ -1530,7 +1458,8 @@ async def generate_chapter_content_stream(
|
||||
# 准备生成参数
|
||||
generate_kwargs = {
|
||||
"prompt": prompt,
|
||||
"system_prompt": system_prompt_with_style # 🔑 关键:使用系统提示词传递风格
|
||||
"system_prompt": system_prompt_with_style,
|
||||
"tool_choice": "required"
|
||||
}
|
||||
if custom_model:
|
||||
logger.info(f" 使用自定义模型: {custom_model}")
|
||||
@@ -1538,47 +1467,38 @@ async def generate_chapter_content_stream(
|
||||
# 注意:这里使用用户配置的AI服务,模型参数会覆盖默认模型
|
||||
# 如果需要切换provider,需要在前端传递provider参数
|
||||
|
||||
# 流式生成内容
|
||||
# === 生成阶段 ===
|
||||
full_content = ""
|
||||
chunk_count = 0
|
||||
last_progress = 0
|
||||
|
||||
yield await tracker.generating(
|
||||
current_chars=0,
|
||||
estimated_total=target_word_count
|
||||
)
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(**generate_kwargs):
|
||||
full_content += chunk
|
||||
chunk_count += 1
|
||||
|
||||
# 发送内容块
|
||||
yield f"data: {json.dumps({'type': 'content', 'content': chunk}, ensure_ascii=False)}\n\n"
|
||||
yield await tracker.generating_chunk(chunk)
|
||||
|
||||
# 每5个chunk发送一次进度更新(10-95%,更平滑)
|
||||
# 每5个chunk发送一次进度更新
|
||||
if chunk_count % 5 == 0:
|
||||
current_word_count = len(full_content)
|
||||
# 优化进度计算:使用更平滑的递增方式
|
||||
# 基于chunk数量和字数的混合计算,避免大幅跳跃
|
||||
chunk_progress = min(40, chunk_count // 5) # chunk贡献最多40%
|
||||
word_progress = min(45, int((current_word_count / target_word_count) * 45)) # 字数贡献最多45%
|
||||
estimated_progress = min(95, 10 + chunk_progress + word_progress)
|
||||
|
||||
# 只在进度变化时发送
|
||||
if estimated_progress > last_progress:
|
||||
progress_data = {
|
||||
'type': 'progress',
|
||||
'progress': estimated_progress,
|
||||
'message': f'正在创作中... 已生成 {current_word_count} 字',
|
||||
'word_count': current_word_count,
|
||||
'status': 'processing'
|
||||
}
|
||||
yield f"data: {json.dumps(progress_data, ensure_ascii=False)}\n\n"
|
||||
last_progress = estimated_progress
|
||||
yield await tracker.generating(
|
||||
current_chars=len(full_content),
|
||||
estimated_total=target_word_count,
|
||||
message=f'正在创作中... 已生成 {len(full_content)} 字'
|
||||
)
|
||||
|
||||
# 每20个chunk发送心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield f"data: {json.dumps({'type': 'heartbeat'}, ensure_ascii=False)}\n\n"
|
||||
yield await tracker.heartbeat()
|
||||
|
||||
await asyncio.sleep(0) # 让出控制权
|
||||
|
||||
# 发送保存进度
|
||||
yield f"data: {json.dumps({'type': 'progress', 'progress': 97, 'message': '正在保存章节...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
|
||||
# === 保存阶段 ===
|
||||
yield await tracker.saving("正在保存章节...", 0.3)
|
||||
|
||||
# 更新章节内容到数据库
|
||||
old_word_count = current_chapter.word_count or 0
|
||||
@@ -1634,25 +1554,28 @@ async def generate_chapter_content_stream(
|
||||
ai_service=user_ai_service
|
||||
)
|
||||
|
||||
# 发送最终进度100%
|
||||
yield f"data: {json.dumps({'type': 'progress', 'progress': 99, 'message': '创作完成!', 'word_count': new_word_count, 'status': 'success'}, ensure_ascii=False)}\n\n"
|
||||
yield await tracker.saving("章节保存完成", 0.8)
|
||||
|
||||
# 发送完成事件(包含分析任务ID)
|
||||
completion_data = {
|
||||
'type': 'done',
|
||||
'message': '创作完成',
|
||||
# === 完成阶段 ===
|
||||
yield await tracker.complete("创作完成!")
|
||||
|
||||
# 发送结果数据
|
||||
yield await tracker.result({
|
||||
'word_count': new_word_count,
|
||||
'analysis_task_id': task_id
|
||||
}
|
||||
yield f"data: {json.dumps(completion_data, ensure_ascii=False)}\n\n"
|
||||
})
|
||||
|
||||
# 发送分析开始事件
|
||||
analysis_started_data = {
|
||||
'type': 'analysis_started',
|
||||
'task_id': task_id,
|
||||
'message': '章节分析已开始'
|
||||
}
|
||||
yield f"data: {json.dumps(analysis_started_data, ensure_ascii=False)}\n\n"
|
||||
# 发送分析开始事件(使用自定义事件)
|
||||
yield await SSEResponse.send_event(
|
||||
event='analysis_started',
|
||||
data={
|
||||
'task_id': task_id,
|
||||
'message': '章节分析已开始'
|
||||
}
|
||||
)
|
||||
|
||||
# 发送完成信号
|
||||
yield await tracker.done()
|
||||
|
||||
break # 退出async for db_session循环
|
||||
|
||||
@@ -1675,7 +1598,7 @@ async def generate_chapter_content_stream(
|
||||
logger.info("章节生成事务已回滚(异常)")
|
||||
except Exception as rollback_error:
|
||||
logger.error(f"回滚失败: {str(rollback_error)}")
|
||||
yield f"data: {json.dumps({'type': 'error', 'error': str(e)}, ensure_ascii=False)}\n\n"
|
||||
yield await tracker.error(str(e))
|
||||
finally:
|
||||
# 确保数据库会话被正确关闭
|
||||
if db_session:
|
||||
@@ -2813,7 +2736,8 @@ async def generate_single_chapter_for_batch(
|
||||
# 准备生成参数
|
||||
generate_kwargs = {
|
||||
"prompt": prompt,
|
||||
"system_prompt": system_prompt_with_style # 🔑 关键:使用系统提示词传递风格
|
||||
"system_prompt": system_prompt_with_style,
|
||||
"tool_choice": "required"
|
||||
}
|
||||
# 如果传入了自定义模型,使用指定的模型
|
||||
if custom_model:
|
||||
@@ -3029,11 +2953,16 @@ async def regenerate_chapter_stream(
|
||||
db_session = None
|
||||
db_committed = False
|
||||
|
||||
# 初始化标准进度追踪器
|
||||
from app.utils.sse_response import WizardProgressTracker
|
||||
tracker = WizardProgressTracker("章节重新生成")
|
||||
|
||||
try:
|
||||
yield await tracker.start()
|
||||
|
||||
# 创建独立数据库会话
|
||||
async for db_session in get_db(request):
|
||||
# 发送开始事件
|
||||
yield f"data: {json.dumps({'type': 'start', 'message': '开始重新生成章节...'}, ensure_ascii=False)}\n\n"
|
||||
yield await tracker.loading("加载章节信息...", 0.5)
|
||||
|
||||
# 创建重新生成任务
|
||||
regen_task = RegenerationTask(
|
||||
@@ -3062,13 +2991,25 @@ async def regenerate_chapter_stream(
|
||||
task_id = regen_task.id
|
||||
logger.info(f"📝 创建重新生成任务: {task_id}")
|
||||
|
||||
yield f"data: {json.dumps({'type': 'task_created', 'task_id': task_id}, ensure_ascii=False)}\n\n"
|
||||
yield await tracker.preparing("准备重新生成...")
|
||||
|
||||
yield await SSEResponse.send_event(
|
||||
event='task_created',
|
||||
data={'task_id': task_id}
|
||||
)
|
||||
|
||||
# 初始化重新生成器
|
||||
regenerator = ChapterRegenerator(user_ai_service)
|
||||
|
||||
# 流式生成新内容
|
||||
# === 生成阶段 ===
|
||||
full_content = ""
|
||||
estimated_total = regenerate_request.target_word_count or len(chapter.content)
|
||||
|
||||
yield await tracker.generating(
|
||||
current_chars=0,
|
||||
estimated_total=estimated_total
|
||||
)
|
||||
|
||||
async for event in regenerator.regenerate_with_feedback(
|
||||
chapter=chapter,
|
||||
analysis=analysis,
|
||||
@@ -3083,19 +3024,35 @@ async def regenerate_chapter_stream(
|
||||
# 内容块
|
||||
chunk = event['content']
|
||||
full_content += chunk
|
||||
yield f"data: {json.dumps({'type': 'chunk', 'content': chunk}, ensure_ascii=False)}\n\n"
|
||||
yield await tracker.generating_chunk(chunk)
|
||||
|
||||
# 定期更新进度
|
||||
if len(full_content) % 500 == 0:
|
||||
yield await tracker.generating(
|
||||
current_chars=len(full_content),
|
||||
estimated_total=estimated_total,
|
||||
message=f'重新生成中... 已生成 {len(full_content)} 字'
|
||||
)
|
||||
elif event['type'] == 'progress':
|
||||
# 进度更新
|
||||
progress_data = {
|
||||
'type': 'progress',
|
||||
'progress': event.get('progress', 0),
|
||||
'message': event.get('message', ''),
|
||||
'word_count': event.get('word_count', 0)
|
||||
}
|
||||
yield f"data: {json.dumps(progress_data, ensure_ascii=False)}\n\n"
|
||||
# 进度更新 - 映射到对应阶段
|
||||
progress = event.get('progress', 0)
|
||||
message = event.get('message', '')
|
||||
if progress < 20:
|
||||
yield await tracker.preparing(message)
|
||||
elif progress < 85:
|
||||
yield await tracker.generating(
|
||||
current_chars=len(full_content),
|
||||
estimated_total=estimated_total,
|
||||
message=message
|
||||
)
|
||||
else:
|
||||
yield await tracker.parsing(message)
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# === 保存阶段 ===
|
||||
yield await tracker.saving("保存重新生成的内容...", 0.5)
|
||||
|
||||
# 更新任务状态
|
||||
regen_task.status = 'completed'
|
||||
regen_task.regenerated_content = full_content
|
||||
@@ -3108,25 +3065,22 @@ async def regenerate_chapter_stream(
|
||||
await db_session.commit()
|
||||
db_committed = True
|
||||
|
||||
# 先发送结果数据
|
||||
result_data = {
|
||||
'type': 'result',
|
||||
'data': {
|
||||
'task_id': task_id,
|
||||
'word_count': len(full_content),
|
||||
'version_number': regen_task.version_number,
|
||||
'auto_applied': regenerate_request.auto_apply,
|
||||
'diff_stats': diff_stats
|
||||
}
|
||||
}
|
||||
yield f"data: {json.dumps(result_data, ensure_ascii=False)}\n\n"
|
||||
yield await tracker.saving("保存完成", 0.9)
|
||||
|
||||
# 再发送完成事件
|
||||
completion_data = {
|
||||
'type': 'done',
|
||||
'message': '重新生成完成'
|
||||
}
|
||||
yield f"data: {json.dumps(completion_data, ensure_ascii=False)}\n\n"
|
||||
# === 完成阶段 ===
|
||||
yield await tracker.complete("重新生成完成!")
|
||||
|
||||
# 发送结果数据
|
||||
yield await tracker.result({
|
||||
'task_id': task_id,
|
||||
'word_count': len(full_content),
|
||||
'version_number': regen_task.version_number,
|
||||
'auto_applied': regenerate_request.auto_apply,
|
||||
'diff_stats': diff_stats
|
||||
})
|
||||
|
||||
# 发送完成信号
|
||||
yield await tracker.done()
|
||||
|
||||
logger.info(f"✅ 章节重新生成完成: {chapter_id}, 任务: {task_id}")
|
||||
|
||||
@@ -3151,7 +3105,7 @@ async def regenerate_chapter_stream(
|
||||
except Exception as update_error:
|
||||
logger.error(f"更新任务失败状态失败: {str(update_error)}")
|
||||
|
||||
yield f"data: {json.dumps({'type': 'error', 'error': str(e)}, ensure_ascii=False)}\n\n"
|
||||
yield await tracker.error(str(e))
|
||||
|
||||
finally:
|
||||
if db_session:
|
||||
|
||||
+28
-130
@@ -775,148 +775,46 @@ async def generate_character_stream(
|
||||
logger.info(f"🎯 开始为项目 {request.project_id} 生成角色(SSE流式)")
|
||||
|
||||
try:
|
||||
# 🔧 MCP工具增强:静默检查并收集参考资料
|
||||
# 直接使用 AIService 流式生成
|
||||
ai_response = ""
|
||||
chunk_count = 0
|
||||
|
||||
if user_id:
|
||||
try:
|
||||
from app.services.mcp_tool_service import mcp_tool_service
|
||||
available_tools = await mcp_tool_service.get_user_enabled_tools(
|
||||
user_id=user_id,
|
||||
db_session=db
|
||||
)
|
||||
|
||||
# 只在有工具时才调用
|
||||
if available_tools:
|
||||
logger.info(f"🔍 检测到可用MCP工具,尝试收集参考资料...")
|
||||
result = await user_ai_service.generate_text_with_mcp(
|
||||
prompt=prompt,
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=2,
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
)
|
||||
|
||||
if isinstance(result, dict):
|
||||
ai_response = result.get('content', '')
|
||||
finish_reason = result.get('finish_reason', '')
|
||||
tool_calls_made = result.get('tool_calls_made', 0)
|
||||
|
||||
# 🔧 修复:检查工具调用是否真正成功
|
||||
if tool_calls_made > 0:
|
||||
if finish_reason == 'tool_error':
|
||||
logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式")
|
||||
# 工具调用失败,重新用基础模式生成
|
||||
ai_response = ""
|
||||
elif not ai_response.strip():
|
||||
logger.warning(f"⚠️ MCP工具调用后返回空响应,降级为基础模式")
|
||||
# 工具调用成功但返回空内容,重新生成
|
||||
ai_response = ""
|
||||
else:
|
||||
logger.info(f"✅ MCP工具调用成功({tool_calls_made}次),内容长度: {len(ai_response)}")
|
||||
# MCP成功且有内容,模拟流式输出(分块发送)
|
||||
chunk_size = 50
|
||||
for i in range(0, len(ai_response), chunk_size):
|
||||
chunk = ai_response[i:i+chunk_size]
|
||||
chunk_count += 1
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
if chunk_count % 3 == 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成角色中... ({i+len(chunk)}/{len(ai_response)}字符)",
|
||||
10 + min(85 * (i+len(chunk)) // len(ai_response), 85)
|
||||
)
|
||||
|
||||
# 跳过后续的流式生成
|
||||
ai_response = result.get('content', '')
|
||||
else:
|
||||
ai_response = result
|
||||
|
||||
# 如果MCP调用失败或返回空,继续走流式生成
|
||||
if not ai_response or not ai_response.strip():
|
||||
logger.info(f"🔄 开始流式生成...")
|
||||
ai_response = ""
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
chunk_count += 1
|
||||
ai_response += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度
|
||||
if chunk_count % 5 == 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成角色中... ({len(ai_response)}字符)",
|
||||
10 + min(chunk_count // 2, 85)
|
||||
)
|
||||
|
||||
# 心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
else:
|
||||
logger.debug(f"用户 {user_id} 未启用MCP工具,使用流式基础模式")
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
chunk_count += 1
|
||||
ai_response += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度
|
||||
if chunk_count % 5 == 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成角色中... ({len(ai_response)}字符)",
|
||||
10 + min(chunk_count // 2, 85)
|
||||
)
|
||||
|
||||
# 心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
except Exception as mcp_error:
|
||||
logger.warning(f"⚠️ MCP工具调用异常,降级为流式基础模式: {str(mcp_error)}")
|
||||
ai_response = ""
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
chunk_count += 1
|
||||
ai_response += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度
|
||||
if chunk_count % 5 == 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成角色中... ({len(ai_response)}字符)",
|
||||
10 + min(chunk_count // 2, 85)
|
||||
)
|
||||
|
||||
# 心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
else:
|
||||
logger.debug(f"未登录用户,使用流式基础模式")
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
chunk_count += 1
|
||||
ai_response += chunk
|
||||
logger.info(f"🎯 开始生成角色(流式模式)...")
|
||||
yield await SSEResponse.send_progress("🎯 开始生成角色...", 15)
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
tool_choice="required",
|
||||
):
|
||||
# chunk 现在可能是 dict 或 str,提取 content 字段
|
||||
if isinstance(chunk, dict):
|
||||
content = chunk.get("content", "")
|
||||
else:
|
||||
content = chunk
|
||||
|
||||
if content:
|
||||
ai_response += content
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
yield await SSEResponse.send_chunk(content)
|
||||
|
||||
# 定期更新进度
|
||||
if chunk_count % 5 == 0:
|
||||
# 定期更新进度(每收到约500字符更新一次,避免过于频繁)
|
||||
current_len = len(ai_response)
|
||||
if current_len >= chunk_count * 500:
|
||||
chunk_count += 1
|
||||
# 使用实际字符数量计算进度,上限85%(留15%给后续解析和保存)
|
||||
# 估算最终字符数约为提示词的8倍,最少3000字符
|
||||
estimated_total = max(3000, len(prompt) * 8)
|
||||
progress = min(15 + int(current_len / estimated_total * 70), 85)
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成角色中... ({len(ai_response)}字符)",
|
||||
10 + min(chunk_count // 2, 85)
|
||||
f"AI生成角色中... ({current_len}字符)",
|
||||
progress
|
||||
)
|
||||
|
||||
# 心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
|
||||
except Exception as ai_error:
|
||||
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
|
||||
yield await SSEResponse.send_error(f"AI服务调用失败:{str(ai_error)}")
|
||||
|
||||
+210
-73
@@ -1,4 +1,7 @@
|
||||
"""MCP插件管理API"""
|
||||
"""MCP插件管理API
|
||||
|
||||
重构后使用统一的MCPClientFacade门面来管理所有MCP操作。
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
@@ -17,9 +20,8 @@ from app.schemas.mcp_plugin import (
|
||||
)
|
||||
import json
|
||||
from app.user_manager import User
|
||||
from app.mcp.registry import mcp_registry
|
||||
from app.mcp import mcp_client, MCPPluginConfig, PluginStatus
|
||||
from app.services.mcp_test_service import mcp_test_service
|
||||
from app.services.mcp_tool_service import mcp_tool_service
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -34,6 +36,31 @@ def require_login(request: Request) -> User:
|
||||
return request.state.user
|
||||
|
||||
|
||||
async def _register_plugin_to_facade(plugin: MCPPlugin, user_id: str) -> bool:
|
||||
"""
|
||||
将插件注册到统一门面
|
||||
|
||||
Args:
|
||||
plugin: 插件对象
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
是否注册成功
|
||||
"""
|
||||
if plugin.plugin_type in ["http", "streamable_http", "sse"] and plugin.server_url:
|
||||
return await mcp_client.register(MCPPluginConfig(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin.plugin_name,
|
||||
url=plugin.server_url,
|
||||
plugin_type=plugin.plugin_type,
|
||||
headers=plugin.headers,
|
||||
timeout=plugin.config.get('timeout', 60.0) if plugin.config else 60.0
|
||||
))
|
||||
else:
|
||||
logger.warning(f"暂不支持的插件类型: {plugin.plugin_type}")
|
||||
return False
|
||||
|
||||
|
||||
@router.get("", response_model=List[MCPPluginResponse])
|
||||
async def list_plugins(
|
||||
enabled_only: bool = Query(False, description="只返回启用的插件"),
|
||||
@@ -99,9 +126,9 @@ async def create_plugin(
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
|
||||
# 如果启用,加载到注册表
|
||||
# 如果启用,注册到统一门面
|
||||
if plugin.enabled:
|
||||
success = await mcp_registry.load_plugin(plugin)
|
||||
success = await _register_plugin_to_facade(plugin, user.user_id)
|
||||
if success:
|
||||
plugin.status = "active"
|
||||
else:
|
||||
@@ -153,7 +180,7 @@ async def create_plugin_simple(
|
||||
# 提取配置
|
||||
server_type = server_config.get("type", "http")
|
||||
|
||||
if server_type not in ["http", "stdio"]:
|
||||
if server_type not in ["http", "stdio", "streamable_http", "sse"]:
|
||||
raise HTTPException(status_code=400, detail=f"不支持的服务器类型: {server_type}")
|
||||
|
||||
# 检查插件名是否已存在
|
||||
@@ -175,12 +202,12 @@ async def create_plugin_simple(
|
||||
"sort_order": 0
|
||||
}
|
||||
|
||||
if server_type == "http":
|
||||
if server_type in ["http", "streamable_http", "sse"]:
|
||||
plugin_data["server_url"] = server_config.get("url")
|
||||
plugin_data["headers"] = server_config.get("headers", {})
|
||||
|
||||
if not plugin_data["server_url"]:
|
||||
raise HTTPException(status_code=400, detail="HTTP类型插件必须提供url字段")
|
||||
raise HTTPException(status_code=400, detail=f"{server_type}类型插件必须提供url字段")
|
||||
|
||||
elif server_type == "stdio":
|
||||
plugin_data["command"] = server_config.get("command")
|
||||
@@ -194,9 +221,9 @@ async def create_plugin_simple(
|
||||
# 更新现有插件
|
||||
logger.info(f"插件 {plugin_name} 已存在,执行更新操作")
|
||||
|
||||
# 先卸载旧插件
|
||||
if existing.enabled:
|
||||
await mcp_registry.unload_plugin(user.user_id, existing.plugin_name)
|
||||
# 保存旧状态
|
||||
old_enabled = existing.enabled
|
||||
old_plugin_name = existing.plugin_name
|
||||
|
||||
# 更新字段
|
||||
for key, value in plugin_data.items():
|
||||
@@ -206,17 +233,24 @@ async def create_plugin_simple(
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
|
||||
# 如果启用,重新加载
|
||||
# 数据库完成后进行MCP操作
|
||||
if old_enabled:
|
||||
try:
|
||||
await mcp_client.unregister(user.user_id, old_plugin_name)
|
||||
except Exception as e:
|
||||
logger.warning(f"注销旧插件出错: {e}")
|
||||
|
||||
if plugin.enabled:
|
||||
success = await mcp_registry.load_plugin(plugin)
|
||||
if success:
|
||||
plugin.status = "active"
|
||||
plugin.last_error = None
|
||||
else:
|
||||
try:
|
||||
success = await _register_plugin_to_facade(plugin, user.user_id)
|
||||
plugin.status = "active" if success else "error"
|
||||
plugin.last_error = None if success else "加载失败"
|
||||
await db.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"注册插件失败: {e}")
|
||||
plugin.status = "error"
|
||||
plugin.last_error = "加载失败"
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
plugin.last_error = str(e)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"用户 {user.user_id} 更新插件: {plugin_name}")
|
||||
else:
|
||||
@@ -230,16 +264,18 @@ async def create_plugin_simple(
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
|
||||
# 如果启用,加载到注册表
|
||||
# 数据库完成后进行MCP操作
|
||||
if plugin.enabled:
|
||||
success = await mcp_registry.load_plugin(plugin)
|
||||
if success:
|
||||
plugin.status = "active"
|
||||
else:
|
||||
try:
|
||||
success = await _register_plugin_to_facade(plugin, user.user_id)
|
||||
plugin.status = "active" if success else "error"
|
||||
plugin.last_error = None if success else "加载失败"
|
||||
await db.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"注册插件失败: {e}")
|
||||
plugin.status = "error"
|
||||
plugin.last_error = "加载失败"
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
plugin.last_error = str(e)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"用户 {user.user_id} 通过简化配置创建插件: {plugin_name}")
|
||||
|
||||
@@ -306,9 +342,10 @@ async def update_plugin(
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
|
||||
# 如果插件已启用,重新加载
|
||||
# 如果插件已启用,重新注册
|
||||
if plugin.enabled:
|
||||
await mcp_registry.reload_plugin(plugin)
|
||||
await mcp_client.unregister(user.user_id, plugin.plugin_name)
|
||||
await _register_plugin_to_facade(plugin, user.user_id)
|
||||
|
||||
logger.info(f"用户 {user.user_id} 更新插件: {plugin.plugin_name}")
|
||||
return plugin
|
||||
@@ -334,8 +371,8 @@ async def delete_plugin(
|
||||
if not plugin:
|
||||
raise HTTPException(status_code=404, detail="插件不存在")
|
||||
|
||||
# 从注册表卸载
|
||||
await mcp_registry.unload_plugin(user.user_id, plugin.plugin_name)
|
||||
# 从统一门面注销
|
||||
await mcp_client.unregister(user.user_id, plugin.plugin_name)
|
||||
|
||||
# 删除数据库记录
|
||||
await db.delete(plugin)
|
||||
@@ -366,27 +403,57 @@ async def toggle_plugin(
|
||||
if not plugin:
|
||||
raise HTTPException(status_code=404, detail="插件不存在")
|
||||
|
||||
plugin.enabled = enabled
|
||||
# 保存插件信息用于后续MCP操作
|
||||
plugin_name = plugin.plugin_name
|
||||
plugin_type = plugin.plugin_type
|
||||
server_url = plugin.server_url
|
||||
headers = plugin.headers
|
||||
config = plugin.config
|
||||
|
||||
if enabled:
|
||||
# 启用:加载到注册表
|
||||
success = await mcp_registry.load_plugin(plugin)
|
||||
if success:
|
||||
plugin.status = "active"
|
||||
plugin.last_error = None
|
||||
else:
|
||||
plugin.status = "error"
|
||||
plugin.last_error = "加载失败"
|
||||
else:
|
||||
# 禁用:从注册表卸载
|
||||
await mcp_registry.unload_plugin(user.user_id, plugin.plugin_name)
|
||||
# 先更新数据库状态
|
||||
plugin.enabled = enabled
|
||||
if not enabled:
|
||||
plugin.status = "inactive"
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
|
||||
# 数据库操作完成后,再进行MCP操作
|
||||
if enabled:
|
||||
# 启用:注册到统一门面
|
||||
try:
|
||||
if plugin_type in ["http", "streamable_http", "sse"] and server_url:
|
||||
success = await mcp_client.register(MCPPluginConfig(
|
||||
user_id=user.user_id,
|
||||
plugin_name=plugin_name,
|
||||
url=server_url,
|
||||
plugin_type=plugin_type,
|
||||
headers=headers,
|
||||
timeout=config.get('timeout', 60.0) if config else 60.0
|
||||
))
|
||||
else:
|
||||
success = False
|
||||
|
||||
# 更新状态
|
||||
plugin.status = "active" if success else "error"
|
||||
plugin.last_error = None if success else "加载失败"
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
except Exception as e:
|
||||
logger.error(f"注册插件失败: {plugin_name}, 错误: {e}")
|
||||
plugin.status = "error"
|
||||
plugin.last_error = str(e)
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
else:
|
||||
# 禁用:从统一门面注销(不影响数据库状态)
|
||||
try:
|
||||
await mcp_client.unregister(user.user_id, plugin_name)
|
||||
except Exception as e:
|
||||
logger.warning(f"注销插件时出错(可忽略): {plugin_name}, 错误: {e}")
|
||||
|
||||
action = "启用" if enabled else "禁用"
|
||||
logger.info(f"用户 {user.user_id} {action}插件: {plugin.plugin_name}")
|
||||
logger.info(f"用户 {user.user_id} {action}插件: {plugin_name}")
|
||||
return plugin
|
||||
|
||||
|
||||
@@ -399,7 +466,7 @@ async def test_plugin(
|
||||
"""
|
||||
测试插件连接并调用工具验证功能
|
||||
|
||||
使用新的MCPTestService进行测试
|
||||
使用MCPTestService进行测试
|
||||
"""
|
||||
|
||||
result = await db.execute(
|
||||
@@ -421,7 +488,7 @@ async def test_plugin(
|
||||
suggestions=["点击开关按钮启用插件"]
|
||||
)
|
||||
|
||||
# 使用新的测试服务
|
||||
# 使用测试服务
|
||||
try:
|
||||
test_result = await mcp_test_service.test_plugin_with_ai(plugin, user, db)
|
||||
|
||||
@@ -447,32 +514,77 @@ async def test_plugin(
|
||||
raise HTTPException(status_code=500, detail=f"测试失败: {str(e)}")
|
||||
|
||||
|
||||
async def _ensure_plugin_loaded(
|
||||
async def _ensure_plugin_registered(
|
||||
plugin: MCPPlugin,
|
||||
user_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
确保插件已加载(共享逻辑)
|
||||
确保插件已注册到统一门面
|
||||
|
||||
Args:
|
||||
plugin: 插件对象
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
是否加载成功
|
||||
是否成功
|
||||
|
||||
Raises:
|
||||
HTTPException: 加载失败
|
||||
HTTPException: 注册失败
|
||||
"""
|
||||
if not mcp_registry.get_client(user_id, plugin.plugin_name):
|
||||
logger.info(f"插件 {plugin.plugin_name} 未加载,自动加载中...")
|
||||
success = await mcp_registry.load_plugin(plugin)
|
||||
try:
|
||||
# 使用ensure_registered方法,它会检查是否已注册
|
||||
if plugin.plugin_type in ["http", "streamable_http", "sse"] and plugin.server_url:
|
||||
return await mcp_client.ensure_registered(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin.plugin_name,
|
||||
url=plugin.server_url,
|
||||
plugin_type=plugin.plugin_type,
|
||||
headers=plugin.headers
|
||||
)
|
||||
return False
|
||||
except ValueError as e:
|
||||
logger.info(f"插件 {plugin.plugin_name} 未注册,自动注册中...")
|
||||
success = await _register_plugin_to_facade(plugin, user_id)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"插件加载失败: {plugin.plugin_name}"
|
||||
detail=f"插件注册失败: {plugin.plugin_name}"
|
||||
)
|
||||
return True
|
||||
return True
|
||||
|
||||
|
||||
@router.get("/{plugin_id}/status")
|
||||
async def get_plugin_status(
|
||||
plugin_id: str,
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取插件的实时状态(包括内存中的会话状态)"""
|
||||
result = await db.execute(
|
||||
select(MCPPlugin).where(
|
||||
MCPPlugin.id == plugin_id,
|
||||
MCPPlugin.user_id == user.user_id
|
||||
)
|
||||
)
|
||||
plugin = result.scalar_one_or_none()
|
||||
|
||||
if not plugin:
|
||||
raise HTTPException(status_code=404, detail="插件不存在")
|
||||
|
||||
session_stats = mcp_client.get_session_stats()
|
||||
session_key = f"{user.user_id}:{plugin.plugin_name}"
|
||||
session_info = next((s for s in session_stats.get("sessions", []) if s["key"] == session_key), None)
|
||||
|
||||
return {
|
||||
"plugin_id": plugin_id,
|
||||
"plugin_name": plugin.plugin_name,
|
||||
"db_status": plugin.status,
|
||||
"session_status": session_info["status"] if session_info else None,
|
||||
"is_registered": session_info is not None,
|
||||
"error_rate": session_info["error_rate"] if session_info else 0,
|
||||
"in_sync": (plugin.status == session_info["status"]) if session_info else (plugin.status == "inactive"),
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.get("/metrics")
|
||||
@@ -495,7 +607,8 @@ async def get_metrics(
|
||||
- avg_duration_ms: 平均耗时(毫秒)
|
||||
- last_call_time: 最后调用时间
|
||||
"""
|
||||
metrics = mcp_tool_service.get_metrics(tool_name)
|
||||
# 使用统一门面获取指标
|
||||
metrics = mcp_client.get_metrics(tool_name)
|
||||
|
||||
return {
|
||||
"metrics": metrics,
|
||||
@@ -518,7 +631,8 @@ async def get_cache_stats(
|
||||
- cache_ttl_minutes: 缓存TTL(分钟)
|
||||
- entries: 各缓存条目详情
|
||||
"""
|
||||
stats = mcp_tool_service.get_cache_stats()
|
||||
# 使用统一门面获取缓存统计
|
||||
stats = mcp_client.get_cache_stats()
|
||||
|
||||
return {
|
||||
"cache_stats": stats,
|
||||
@@ -526,6 +640,27 @@ async def get_cache_stats(
|
||||
}
|
||||
|
||||
|
||||
@router.get("/sessions/stats")
|
||||
async def get_session_stats(
|
||||
user: User = Depends(require_login)
|
||||
):
|
||||
"""
|
||||
获取MCP会话统计信息
|
||||
|
||||
Returns:
|
||||
会话统计信息,包含:
|
||||
- total_sessions: 会话总数
|
||||
- sessions: 各会话详情
|
||||
"""
|
||||
# 使用统一门面获取会话统计
|
||||
stats = mcp_client.get_session_stats()
|
||||
|
||||
return {
|
||||
"session_stats": stats,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.post("/cache/clear")
|
||||
async def clear_cache(
|
||||
user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
@@ -551,7 +686,8 @@ async def clear_cache(
|
||||
# 如果没有指定user_id,使用当前用户
|
||||
target_user_id = user_id or user.user_id
|
||||
|
||||
mcp_tool_service.clear_cache(target_user_id, plugin_name)
|
||||
# 使用统一门面清理缓存
|
||||
mcp_client.clear_cache(target_user_id, plugin_name)
|
||||
|
||||
message = "已清理"
|
||||
if plugin_name:
|
||||
@@ -594,12 +730,13 @@ async def get_plugin_tools(
|
||||
raise HTTPException(status_code=400, detail="插件未启用")
|
||||
|
||||
try:
|
||||
# 确保插件已加载
|
||||
await _ensure_plugin_loaded(plugin, user.user_id)
|
||||
# 确保插件已注册
|
||||
await _ensure_plugin_registered(plugin, user.user_id)
|
||||
|
||||
tools = await mcp_registry.get_plugin_tools(user.user_id, plugin.plugin_name)
|
||||
# 使用统一门面获取工具列表
|
||||
tools = await mcp_client.get_tools(user.user_id, plugin.plugin_name)
|
||||
|
||||
# 更新缓存
|
||||
# 更新数据库中的工具缓存
|
||||
plugin.tools = tools
|
||||
await db.commit()
|
||||
|
||||
@@ -640,22 +777,22 @@ async def call_mcp_tool(
|
||||
raise HTTPException(status_code=400, detail="插件未启用")
|
||||
|
||||
try:
|
||||
# 确保插件已加载
|
||||
await _ensure_plugin_loaded(plugin, user.user_id)
|
||||
# 确保插件已注册
|
||||
await _ensure_plugin_registered(plugin, user.user_id)
|
||||
|
||||
# 调用工具
|
||||
result = await mcp_registry.call_tool(
|
||||
user.user_id,
|
||||
plugin.plugin_name,
|
||||
data.tool_name,
|
||||
data.arguments
|
||||
# 使用统一门面调用工具
|
||||
tool_result = await mcp_client.call_tool(
|
||||
user_id=user.user_id,
|
||||
plugin_name=plugin.plugin_name,
|
||||
tool_name=data.tool_name,
|
||||
arguments=data.arguments
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"plugin_name": plugin.plugin_name,
|
||||
"tool_name": data.tool_name,
|
||||
"result": result
|
||||
"result": tool_result
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
|
||||
+235
-552
File diff suppressed because it is too large
Load Diff
+253
-7
@@ -22,7 +22,7 @@ from app.schemas.settings import (
|
||||
from app.user_manager import User
|
||||
from app.logger import get_logger
|
||||
from app.config import settings as app_settings, PROJECT_ROOT
|
||||
from app.services.ai_service import AIService, create_user_ai_service
|
||||
from app.services.ai_service import AIService, create_user_ai_service, create_user_ai_service_with_mcp
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -53,9 +53,14 @@ async def get_user_ai_service(
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> AIService:
|
||||
"""
|
||||
依赖:获取当前用户的AI服务实例
|
||||
从数据库读取用户设置并创建对应的AI服务
|
||||
依赖:获取当前用户的AI服务实例(支持MCP工具自动加载)
|
||||
|
||||
从数据库读取用户设置并创建对应的AI服务。
|
||||
自动传递 user_id 和 db_session,使得 AIService 能够加载用户配置的MCP工具。
|
||||
根据用户的所有MCP插件状态决定是否启用MCP:如果有启用的插件则启用,否则禁用。
|
||||
"""
|
||||
from app.models.mcp_plugin import MCPPlugin
|
||||
|
||||
result = await db.execute(
|
||||
select(Settings).where(Settings.user_id == user.user_id)
|
||||
)
|
||||
@@ -73,15 +78,34 @@ async def get_user_ai_service(
|
||||
await db.refresh(settings)
|
||||
logger.info(f"用户 {user.user_id} 首次使用AI服务,已从.env同步设置到数据库")
|
||||
|
||||
# 使用用户设置创建AI服务实例(包括系统提示词)
|
||||
return create_user_ai_service(
|
||||
# 查询用户的所有MCP插件状态
|
||||
mcp_result = await db.execute(
|
||||
select(MCPPlugin).where(MCPPlugin.user_id == user.user_id)
|
||||
)
|
||||
mcp_plugins = mcp_result.scalars().all()
|
||||
|
||||
# 检查是否有启用的MCP插件
|
||||
enable_mcp = any(plugin.enabled for plugin in mcp_plugins) if mcp_plugins else False
|
||||
|
||||
if mcp_plugins:
|
||||
enabled_count = sum(1 for p in mcp_plugins if p.enabled)
|
||||
logger.info(f"用户 {user.user_id} 有 {len(mcp_plugins)} 个MCP插件,{enabled_count} 个启用,{enable_mcp} 决定使用MCP")
|
||||
else:
|
||||
logger.debug(f"用户 {user.user_id} 没有配置MCP插件,禁用MCP")
|
||||
|
||||
# ✅ 使用支持MCP的工厂函数创建AI服务实例
|
||||
# 传递 user_id 和 db_session,使得 AIService 能够自动加载用户配置的MCP工具
|
||||
return create_user_ai_service_with_mcp(
|
||||
api_provider=settings.api_provider,
|
||||
api_key=settings.api_key,
|
||||
api_base_url=settings.api_base_url or "",
|
||||
model_name=settings.llm_model,
|
||||
temperature=settings.temperature,
|
||||
max_tokens=settings.max_tokens,
|
||||
system_prompt=settings.system_prompt # 传递系统提示词
|
||||
user_id=user.user_id, # ✅ 传递 user_id
|
||||
db_session=db, # ✅ 传递 db_session
|
||||
system_prompt=settings.system_prompt,
|
||||
enable_mcp=enable_mcp, # 根据MCP插件状态动态决定
|
||||
)
|
||||
|
||||
|
||||
@@ -327,6 +351,227 @@ class ApiTestRequest(BaseModel):
|
||||
llm_model: str
|
||||
|
||||
|
||||
@router.post("/check-function-calling")
|
||||
async def check_function_calling_support(data: ApiTestRequest):
|
||||
"""
|
||||
检查模型是否支持 Function Calling(工具调用)
|
||||
|
||||
基于业界最佳实践的测试方法:
|
||||
1. 发送包含工具定义的请求
|
||||
2. 检查响应的 finish_reason 是否为 "tool_calls"
|
||||
3. 验证响应中是否包含有效的 tool_calls 数据
|
||||
|
||||
Args:
|
||||
data: 包含 API 配置的请求数据
|
||||
|
||||
Returns:
|
||||
检测结果包含支持状态、详细信息和建议
|
||||
"""
|
||||
api_key = data.api_key
|
||||
api_base_url = data.api_base_url
|
||||
provider = data.provider
|
||||
llm_model = data.llm_model
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# 定义一个简单的测试工具(天气查询)
|
||||
test_tools = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "获取指定城市的当前天气信息",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "城市名称,例如:北京、上海、深圳"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "温度单位"
|
||||
}
|
||||
},
|
||||
"required": ["city"]
|
||||
}
|
||||
}
|
||||
}]
|
||||
|
||||
# 测试提示:故意设计一个需要调用工具的问题
|
||||
test_prompt = "请告诉我北京现在的天气情况如何?"
|
||||
|
||||
logger.info(f"🧪 开始检测 Function Calling 支持")
|
||||
logger.info(f" - 提供商: {provider}")
|
||||
logger.info(f" - 模型: {llm_model}")
|
||||
logger.info(f" - 测试工具: get_weather")
|
||||
|
||||
# 创建临时 AI 服务实例进行测试
|
||||
test_service = AIService(
|
||||
api_provider=provider,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
default_model=llm_model,
|
||||
default_temperature=0.3, # 使用较低温度以获得更确定的行为
|
||||
default_max_tokens=200
|
||||
)
|
||||
|
||||
# 发送带工具的测试请求
|
||||
response = await test_service.generate_text(
|
||||
prompt=test_prompt,
|
||||
provider=provider,
|
||||
model=llm_model,
|
||||
temperature=0.3,
|
||||
max_tokens=200,
|
||||
tools=test_tools,
|
||||
tool_choice="auto", # 让模型自动决定是否使用工具
|
||||
auto_mcp=False # 禁用 MCP 自动加载
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
response_time = round((end_time - start_time) * 1000, 2)
|
||||
|
||||
# 分析响应以确定是否支持 Function Calling
|
||||
supported = False
|
||||
finish_reason = None
|
||||
tool_calls = None
|
||||
response_content = None
|
||||
|
||||
if isinstance(response, dict):
|
||||
# 检查 finish_reason(OpenAI 标准)
|
||||
finish_reason = response.get("finish_reason")
|
||||
|
||||
# 检查是否有 tool_calls
|
||||
if "tool_calls" in response and response["tool_calls"]:
|
||||
supported = True
|
||||
tool_calls = response["tool_calls"]
|
||||
logger.info(f"✅ 检测到工具调用: {len(tool_calls)} 个")
|
||||
|
||||
# 记录返回的内容(如果有)
|
||||
if "content" in response:
|
||||
response_content = response["content"]
|
||||
elif isinstance(response, str):
|
||||
# 如果只返回字符串,说明不支持工具调用
|
||||
response_content = response
|
||||
|
||||
logger.info(f" - 响应时间: {response_time}ms")
|
||||
logger.info(f" - finish_reason: {finish_reason}")
|
||||
logger.info(f" - 支持状态: {'✅ 支持' if supported else '❌ 不支持'}")
|
||||
|
||||
# 构建详细的返回信息
|
||||
result = {
|
||||
"success": True,
|
||||
"supported": supported,
|
||||
"message": "✅ 模型支持 Function Calling" if supported else "❌ 模型不支持 Function Calling",
|
||||
"response_time_ms": response_time,
|
||||
"provider": provider,
|
||||
"model": llm_model,
|
||||
"details": {
|
||||
"finish_reason": finish_reason,
|
||||
"has_tool_calls": bool(tool_calls),
|
||||
"tool_call_count": len(tool_calls) if tool_calls else 0,
|
||||
"test_tool": "get_weather",
|
||||
"test_prompt": test_prompt,
|
||||
"response_type": "tool_calls" if supported else "text"
|
||||
}
|
||||
}
|
||||
|
||||
# 添加工具调用详情
|
||||
if tool_calls:
|
||||
result["tool_calls"] = tool_calls
|
||||
result["suggestions"] = [
|
||||
"✅ 该模型支持 Function Calling,可以正常使用 MCP 插件",
|
||||
"建议:启用需要的 MCP 插件以扩展 AI 能力",
|
||||
"提示:测试成功检测到工具调用,模型能够正确解析和使用外部工具"
|
||||
]
|
||||
else:
|
||||
result["response_preview"] = response_content[:200] if response_content else None
|
||||
result["suggestions"] = [
|
||||
"❌ 该模型不支持 Function Calling,无法使用 MCP 插件功能",
|
||||
"建议:更换支持工具调用的模型",
|
||||
"推荐模型:GPT-4 系列、GPT-4-turbo、Claude 3 Opus/Sonnet、Gemini 1.5 Pro 等",
|
||||
"说明:模型返回了文本回复而非工具调用,表明不支持该功能"
|
||||
]
|
||||
|
||||
return result
|
||||
|
||||
except ValueError as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f"❌ Function Calling 检测配置错误: {error_msg}")
|
||||
return {
|
||||
"success": False,
|
||||
"supported": False,
|
||||
"message": "配置错误",
|
||||
"error": error_msg,
|
||||
"error_type": "ConfigurationError",
|
||||
"suggestions": [
|
||||
"请检查 API Key 是否正确",
|
||||
"请确认 API Base URL 格式是否正确",
|
||||
"请验证所选提供商与配置是否匹配"
|
||||
]
|
||||
}
|
||||
|
||||
except TimeoutError as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f"❌ Function Calling 检测超时: {error_msg}")
|
||||
return {
|
||||
"success": False,
|
||||
"supported": None,
|
||||
"message": "检测超时",
|
||||
"error": error_msg,
|
||||
"error_type": "TimeoutError",
|
||||
"suggestions": [
|
||||
"请检查网络连接是否正常",
|
||||
"请确认 API 服务是否可访问",
|
||||
"建议:稍后重试或使用其他网络环境"
|
||||
]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
error_type = type(e).__name__
|
||||
|
||||
logger.error(f"❌ Function Calling 检测失败: {error_msg}")
|
||||
logger.error(f" - 错误类型: {error_type}")
|
||||
|
||||
# 智能分析错误原因
|
||||
suggestions = []
|
||||
if "tool" in error_msg.lower() or "function" in error_msg.lower():
|
||||
suggestions = [
|
||||
"该模型可能不支持 Function Calling 功能",
|
||||
"API 返回了与工具调用相关的错误",
|
||||
"建议:更换支持工具调用的模型或联系 API 提供商"
|
||||
]
|
||||
elif "unauthorized" in error_msg.lower() or "401" in error_msg:
|
||||
suggestions = [
|
||||
"API Key 认证失败",
|
||||
"请检查 API Key 是否正确且有效",
|
||||
"请确认 API Key 是否有足够的权限"
|
||||
]
|
||||
elif "not found" in error_msg.lower() or "404" in error_msg:
|
||||
suggestions = [
|
||||
"模型不存在或不可用",
|
||||
"请检查模型名称是否正确",
|
||||
"请确认该模型在当前 API 中是否可用"
|
||||
]
|
||||
else:
|
||||
suggestions = [
|
||||
"检测过程中遇到未知错误",
|
||||
"建议:检查所有配置参数是否正确",
|
||||
"提示:查看详细错误信息以获取更多线索"
|
||||
]
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"supported": False,
|
||||
"message": "Function Calling 检测失败",
|
||||
"error": error_msg,
|
||||
"error_type": error_type,
|
||||
"suggestions": suggestions
|
||||
}
|
||||
|
||||
|
||||
@router.post("/test")
|
||||
async def test_api_connection(data: ApiTestRequest):
|
||||
"""
|
||||
@@ -370,7 +615,8 @@ async def test_api_connection(data: ApiTestRequest):
|
||||
provider=provider,
|
||||
model=llm_model,
|
||||
temperature=0.7,
|
||||
max_tokens=8000
|
||||
max_tokens=8000,
|
||||
auto_mcp=False # 测试时不加载MCP工具
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
+361
-425
File diff suppressed because it is too large
Load Diff
@@ -77,10 +77,8 @@ class Settings(BaseSettings):
|
||||
default_temperature: float = 0.7
|
||||
default_max_tokens: int = 32000
|
||||
|
||||
# MCP适配器配置
|
||||
enable_mcp_adapter: bool = True # 是否启用MCP适配器(自动检测API能力)
|
||||
mcp_adapter_cache_ttl_hours: int = 24 # API能力检测缓存时长(小时)
|
||||
mcp_adapter_auto_fallback: bool = True # 是否启用自动降级(FC失败时切换到提示词注入)
|
||||
# MCP配置
|
||||
mcp_max_rounds: int = 3 # MCP工具调用最大轮数(全局统一控制)
|
||||
|
||||
# LinuxDO OAuth2 配置
|
||||
LINUXDO_CLIENT_ID: Optional[str] = None
|
||||
|
||||
@@ -167,7 +167,7 @@ async def get_db(request: Request):
|
||||
_session_stats["created"] += 1
|
||||
_session_stats["active"] += 1
|
||||
|
||||
logger.debug(f"📊 会话创建 [User:{user_id}][ID:{session_id}] - 活跃:{_session_stats['active']}, 总创建:{_session_stats['created']}, 总关闭:{_session_stats['closed']}")
|
||||
# logger.debug(f"📊 会话创建 [User:{user_id}][ID:{session_id}] - 活跃:{_session_stats['active']}, 总创建:{_session_stats['created']}, 总关闭:{_session_stats['closed']}")
|
||||
|
||||
try:
|
||||
yield session
|
||||
|
||||
@@ -130,11 +130,15 @@ def _configure_third_party_loggers():
|
||||
logging.getLogger('sqlalchemy.dialects').setLevel(logging.WARNING)
|
||||
logging.getLogger('sqlalchemy.orm').setLevel(logging.WARNING)
|
||||
|
||||
# aiosqlite - 异步SQLite,禁用DEBUG日志
|
||||
logging.getLogger('aiosqlite').setLevel(logging.WARNING)
|
||||
|
||||
# Watchfiles - 开发时的文件监控,降低级别
|
||||
logging.getLogger('watchfiles').setLevel(logging.WARNING)
|
||||
|
||||
# httpx - HTTP客户端
|
||||
# httpx/httpcore - HTTP客户端,禁用DEBUG日志
|
||||
logging.getLogger('httpx').setLevel(logging.WARNING)
|
||||
logging.getLogger('httpcore').setLevel(logging.WARNING)
|
||||
|
||||
# openai/anthropic - AI客户端库
|
||||
logging.getLogger('openai').setLevel(logging.WARNING)
|
||||
|
||||
+5
-2
@@ -12,7 +12,7 @@ from app.database import close_db, _session_stats
|
||||
from app.logger import setup_logging, get_logger
|
||||
from app.middleware import RequestIDMiddleware
|
||||
from app.middleware.auth_middleware import AuthMiddleware
|
||||
from app.mcp.registry import mcp_registry
|
||||
from app.mcp import mcp_client, register_status_sync
|
||||
|
||||
setup_logging(
|
||||
level=config_settings.log_level,
|
||||
@@ -27,12 +27,15 @@ logger = get_logger(__name__)
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
# 注册MCP状态同步服务
|
||||
register_status_sync()
|
||||
|
||||
logger.info("应用启动完成")
|
||||
|
||||
yield
|
||||
|
||||
# 清理MCP插件
|
||||
await mcp_registry.cleanup_all()
|
||||
await mcp_client.cleanup()
|
||||
|
||||
# 清理HTTP客户端池
|
||||
from app.services.ai_service import cleanup_http_clients
|
||||
|
||||
@@ -1,4 +1,36 @@
|
||||
"""MCP插件系统"""
|
||||
from .registry import mcp_registry
|
||||
"""MCP模块 - 统一的MCP客户端管理
|
||||
|
||||
__all__ = ["mcp_registry"]
|
||||
本模块提供MCP(Model Context Protocol)客户端的统一管理接口。
|
||||
|
||||
推荐使用方式:
|
||||
from app.mcp import mcp_client, MCPPluginConfig
|
||||
|
||||
# 注册插件
|
||||
await mcp_client.register(MCPPluginConfig(
|
||||
user_id="user123",
|
||||
plugin_name="exa-search",
|
||||
url="http://localhost:8000/mcp"
|
||||
))
|
||||
|
||||
# 获取工具
|
||||
tools = await mcp_client.get_tools("user123", "exa-search")
|
||||
|
||||
# 调用工具
|
||||
result = await mcp_client.call_tool("user123", "exa-search", "web_search", {"query": "..."})
|
||||
|
||||
# 注册状态变更回调
|
||||
from app.mcp.status_sync import register_status_sync
|
||||
register_status_sync()
|
||||
"""
|
||||
|
||||
from .facade import mcp_client, MCPClientFacade, MCPPluginConfig, MCPError, PluginStatus
|
||||
from .status_sync import register_status_sync
|
||||
|
||||
__all__ = [
|
||||
"mcp_client",
|
||||
"MCPClientFacade",
|
||||
"MCPPluginConfig",
|
||||
"MCPError",
|
||||
"PluginStatus",
|
||||
"register_status_sync",
|
||||
]
|
||||
@@ -1,14 +0,0 @@
|
||||
"""MCP适配器模块 - 支持多种AI API的工具调用方式"""
|
||||
|
||||
from .base import BaseMCPAdapter, AdapterType
|
||||
from .prompt_injection import PromptInjectionAdapter
|
||||
from .function_calling import FunctionCallingAdapter
|
||||
from .universal import UniversalMCPAdapter
|
||||
|
||||
__all__ = [
|
||||
"BaseMCPAdapter",
|
||||
"AdapterType",
|
||||
"PromptInjectionAdapter",
|
||||
"FunctionCallingAdapter",
|
||||
"UniversalMCPAdapter",
|
||||
]
|
||||
@@ -1,89 +0,0 @@
|
||||
"""MCP适配器基类"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class AdapterType(Enum):
|
||||
"""适配器类型"""
|
||||
FUNCTION_CALLING = "function_calling" # 标准Function Calling
|
||||
PROMPT_INJECTION = "prompt_injection" # 提示词注入
|
||||
REACT = "react" # ReAct模式
|
||||
XML = "xml" # XML标记
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCallResult:
|
||||
"""工具调用结果"""
|
||||
tool_calls: List[Dict[str, Any]] # 解析出的工具调用
|
||||
raw_response: str # 原始AI响应
|
||||
has_tool_calls: bool # 是否包含工具调用
|
||||
needs_continuation: bool = False # 是否需要继续对话
|
||||
|
||||
|
||||
class BaseMCPAdapter(ABC):
|
||||
"""MCP适配器基类"""
|
||||
|
||||
def __init__(self):
|
||||
self.adapter_type: AdapterType = AdapterType.PROMPT_INJECTION
|
||||
|
||||
@abstractmethod
|
||||
def format_tools_for_prompt(
|
||||
self,
|
||||
tools: List[Dict[str, Any]],
|
||||
user_message: str
|
||||
) -> str:
|
||||
"""
|
||||
将工具列表格式化为提示词
|
||||
|
||||
Args:
|
||||
tools: MCP工具列表
|
||||
user_message: 用户消息
|
||||
|
||||
Returns:
|
||||
格式化后的提示词
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def parse_tool_calls(self, ai_response: str) -> ToolCallResult:
|
||||
"""
|
||||
从AI响应中解析工具调用
|
||||
|
||||
Args:
|
||||
ai_response: AI的原始响应
|
||||
|
||||
Returns:
|
||||
解析结果
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def build_continuation_prompt(
|
||||
self,
|
||||
original_message: str,
|
||||
ai_response: str,
|
||||
tool_results: List[Dict[str, Any]]
|
||||
) -> str:
|
||||
"""
|
||||
构建包含工具结果的继续对话提示词
|
||||
|
||||
Args:
|
||||
original_message: 原始用户消息
|
||||
ai_response: AI响应
|
||||
tool_results: 工具执行结果
|
||||
|
||||
Returns:
|
||||
继续对话的提示词
|
||||
"""
|
||||
pass
|
||||
|
||||
def supports_native_tools(self) -> bool:
|
||||
"""是否支持原生工具调用(如Function Calling)"""
|
||||
return False
|
||||
|
||||
def get_adapter_type(self) -> AdapterType:
|
||||
"""获取适配器类型"""
|
||||
return self.adapter_type
|
||||
@@ -1,171 +0,0 @@
|
||||
"""Function Calling适配器 - 支持原生Function Calling的API"""
|
||||
|
||||
import json
|
||||
from typing import Dict, Any, List
|
||||
from app.mcp.adapters.base import BaseMCPAdapter, AdapterType, ToolCallResult
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class FunctionCallingAdapter(BaseMCPAdapter):
|
||||
"""Function Calling适配器 - 用于支持原生工具调用的AI API(如OpenAI)"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.adapter_type = AdapterType.FUNCTION_CALLING
|
||||
|
||||
def supports_native_tools(self) -> bool:
|
||||
"""支持原生工具调用"""
|
||||
return True
|
||||
|
||||
def format_tools_for_prompt(
|
||||
self,
|
||||
tools: List[Dict[str, Any]],
|
||||
user_message: str
|
||||
) -> str:
|
||||
"""
|
||||
Function Calling模式下,工具通过API参数传递,不需要修改提示词
|
||||
|
||||
Returns:
|
||||
原始用户消息
|
||||
"""
|
||||
return user_message
|
||||
|
||||
def get_tools_for_api(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取适用于API的工具格式
|
||||
|
||||
Args:
|
||||
tools: MCP工具列表
|
||||
|
||||
Returns:
|
||||
适用于OpenAI Function Calling的工具格式
|
||||
"""
|
||||
return tools
|
||||
|
||||
def parse_tool_calls(self, ai_response: Any) -> ToolCallResult:
|
||||
"""
|
||||
从AI响应中解析工具调用(Function Calling格式)
|
||||
|
||||
Args:
|
||||
ai_response: AI响应对象(通常是OpenAI的ChatCompletion对象)
|
||||
|
||||
Returns:
|
||||
解析结果
|
||||
"""
|
||||
|
||||
try:
|
||||
# 处理不同类型的响应
|
||||
if isinstance(ai_response, dict):
|
||||
# 字典格式(OpenAI API响应)
|
||||
message = ai_response.get("choices", [{}])[0].get("message", {})
|
||||
tool_calls = message.get("tool_calls", [])
|
||||
content = message.get("content", "")
|
||||
|
||||
elif hasattr(ai_response, "choices"):
|
||||
# 对象格式(OpenAI SDK响应)
|
||||
message = ai_response.choices[0].message
|
||||
tool_calls = getattr(message, "tool_calls", None) or []
|
||||
content = getattr(message, "content", "") or ""
|
||||
|
||||
# 转换为字典格式
|
||||
if tool_calls:
|
||||
tool_calls = [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": tc.type,
|
||||
"function": {
|
||||
"name": tc.function.name,
|
||||
"arguments": tc.function.arguments
|
||||
}
|
||||
}
|
||||
for tc in tool_calls
|
||||
]
|
||||
else:
|
||||
# 字符串格式(降级为文本响应)
|
||||
return ToolCallResult(
|
||||
tool_calls=[],
|
||||
raw_response=str(ai_response),
|
||||
has_tool_calls=False
|
||||
)
|
||||
|
||||
has_tool_calls = len(tool_calls) > 0
|
||||
|
||||
if has_tool_calls:
|
||||
logger.info(f"✅ Function Calling模式解析出 {len(tool_calls)} 个工具调用")
|
||||
for tc in tool_calls:
|
||||
logger.info(f" - {tc['function']['name']}")
|
||||
|
||||
return ToolCallResult(
|
||||
tool_calls=tool_calls,
|
||||
raw_response=content or "",
|
||||
has_tool_calls=has_tool_calls,
|
||||
needs_continuation=has_tool_calls
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 解析Function Calling响应失败: {e}", exc_info=True)
|
||||
return ToolCallResult(
|
||||
tool_calls=[],
|
||||
raw_response=str(ai_response),
|
||||
has_tool_calls=False
|
||||
)
|
||||
|
||||
def build_continuation_prompt(
|
||||
self,
|
||||
original_message: str,
|
||||
ai_response: str,
|
||||
tool_results: List[Dict[str, Any]]
|
||||
) -> str:
|
||||
"""
|
||||
构建包含工具结果的继续对话提示词
|
||||
|
||||
在Function Calling模式下,这通常不需要,因为工具结果会作为消息历史的一部分
|
||||
"""
|
||||
# Function Calling模式下通常通过消息历史传递工具结果
|
||||
# 这里提供一个降级方案
|
||||
results_text = "\n\n".join([
|
||||
f"工具 {r['name']} 的结果:\n{r['content']}"
|
||||
for r in tool_results
|
||||
])
|
||||
|
||||
return f"{original_message}\n\n工具执行结果:\n{results_text}\n\n请基于以上工具结果回答用户的问题。"
|
||||
|
||||
def build_messages_with_tool_results(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
tool_calls: List[Dict[str, Any]],
|
||||
tool_results: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
构建包含工具结果的消息历史(Function Calling标准格式)
|
||||
|
||||
Args:
|
||||
messages: 原始消息历史
|
||||
tool_calls: AI的工具调用
|
||||
tool_results: 工具执行结果
|
||||
|
||||
Returns:
|
||||
更新后的消息历史
|
||||
"""
|
||||
|
||||
new_messages = messages.copy()
|
||||
|
||||
# 添加助手的工具调用消息
|
||||
new_messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": tool_calls
|
||||
})
|
||||
|
||||
# 添加工具结果消息
|
||||
for result in tool_results:
|
||||
new_messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": result.get("tool_call_id", ""),
|
||||
"name": result.get("name", ""),
|
||||
"content": result.get("content", "")
|
||||
})
|
||||
|
||||
return new_messages
|
||||
@@ -1,274 +0,0 @@
|
||||
"""提示词注入适配器 - 最通用的MCP工具调用方式"""
|
||||
|
||||
import re
|
||||
import json
|
||||
from typing import Dict, Any, List
|
||||
from app.mcp.adapters.base import BaseMCPAdapter, AdapterType, ToolCallResult
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PromptInjectionAdapter(BaseMCPAdapter):
|
||||
"""提示词注入适配器 - 将工具转换为文本描述,通过提示词引导AI调用"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.adapter_type = AdapterType.PROMPT_INJECTION
|
||||
|
||||
def format_tools_for_prompt(
|
||||
self,
|
||||
tools: List[Dict[str, Any]],
|
||||
user_message: str
|
||||
) -> str:
|
||||
"""将工具列表注入到提示词中"""
|
||||
|
||||
if not tools:
|
||||
return user_message
|
||||
|
||||
# 格式化工具描述
|
||||
tool_descriptions = self._format_tools_as_text(tools)
|
||||
|
||||
# 构建增强的提示词
|
||||
enhanced_prompt = f"""你现在可以使用以下工具来帮助回答用户的问题。
|
||||
|
||||
## 可用工具
|
||||
|
||||
{tool_descriptions}
|
||||
|
||||
## 工具使用说明
|
||||
|
||||
当你需要使用工具时,请按以下XML格式输出(可以一次调用多个工具):
|
||||
|
||||
<tool_calls>
|
||||
<tool_call>
|
||||
<tool_name>工具名称</tool_name>
|
||||
<arguments>
|
||||
{{
|
||||
"参数名1": "参数值1",
|
||||
"参数名2": "参数值2"
|
||||
}}
|
||||
</arguments>
|
||||
</tool_call>
|
||||
</tool_calls>
|
||||
|
||||
## 重要提示
|
||||
|
||||
1. 只有在确实需要使用工具时才调用工具
|
||||
2. 参数必须是有效的JSON格式
|
||||
3. 仔细检查参数是否符合工具的要求
|
||||
4. 可以在一个<tool_calls>标签内包含多个<tool_call>
|
||||
5. 调用工具后,你会收到工具的执行结果,然后需要基于结果继续回答
|
||||
|
||||
---
|
||||
|
||||
用户问题:{user_message}
|
||||
|
||||
请分析问题,判断是否需要使用工具。如果需要,先输出工具调用,然后等待结果。如果不需要,直接回答问题。"""
|
||||
|
||||
return enhanced_prompt
|
||||
|
||||
def _format_tools_as_text(self, tools: List[Dict[str, Any]]) -> str:
|
||||
"""将工具格式化为可读的文本描述"""
|
||||
lines = []
|
||||
|
||||
for i, tool in enumerate(tools, 1):
|
||||
func = tool.get("function", {})
|
||||
name = func.get("name", "unknown")
|
||||
description = func.get("description", "无描述")
|
||||
parameters = func.get("parameters", {})
|
||||
|
||||
lines.append(f"### {i}. {name}")
|
||||
lines.append(f"**描述**: {description}")
|
||||
lines.append("")
|
||||
|
||||
# 格式化参数信息
|
||||
if parameters and "properties" in parameters:
|
||||
lines.append("**参数**:")
|
||||
properties = parameters.get("properties", {})
|
||||
required = parameters.get("required", [])
|
||||
|
||||
for param_name, param_info in properties.items():
|
||||
param_type = param_info.get("type", "string")
|
||||
param_desc = param_info.get("description", "")
|
||||
is_required = "必填" if param_name in required else "可选"
|
||||
|
||||
lines.append(f" - `{param_name}` ({param_type}, {is_required}): {param_desc}")
|
||||
lines.append("")
|
||||
|
||||
# 添加示例
|
||||
if "example" in func:
|
||||
lines.append(f"**示例**: {json.dumps(func['example'], ensure_ascii=False)}")
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def parse_tool_calls(self, ai_response) -> ToolCallResult:
|
||||
"""从AI响应中解析工具调用"""
|
||||
|
||||
tool_calls = []
|
||||
|
||||
try:
|
||||
# 处理不同类型的响应
|
||||
if isinstance(ai_response, dict):
|
||||
# 如果是字典,提取content字段
|
||||
ai_response = ai_response.get("choices", [{}])[0].get("message", {}).get("content", "")
|
||||
if not ai_response:
|
||||
return ToolCallResult(
|
||||
tool_calls=[],
|
||||
raw_response="",
|
||||
has_tool_calls=False
|
||||
)
|
||||
elif not isinstance(ai_response, str):
|
||||
# 转换为字符串
|
||||
ai_response = str(ai_response)
|
||||
|
||||
# 使用正则提取 <tool_calls> 标签内容
|
||||
tool_calls_match = re.search(
|
||||
r'<tool_calls>(.*?)</tool_calls>',
|
||||
ai_response,
|
||||
re.DOTALL | re.IGNORECASE
|
||||
)
|
||||
|
||||
if not tool_calls_match:
|
||||
# 没有找到工具调用
|
||||
return ToolCallResult(
|
||||
tool_calls=[],
|
||||
raw_response=ai_response,
|
||||
has_tool_calls=False
|
||||
)
|
||||
|
||||
tool_calls_content = tool_calls_match.group(1)
|
||||
|
||||
# 提取所有 <tool_call> 标签
|
||||
tool_call_pattern = r'<tool_call>(.*?)</tool_call>'
|
||||
tool_call_matches = re.findall(
|
||||
tool_call_pattern,
|
||||
tool_calls_content,
|
||||
re.DOTALL | re.IGNORECASE
|
||||
)
|
||||
|
||||
for i, tool_call_content in enumerate(tool_call_matches):
|
||||
# 提取工具名称
|
||||
name_match = re.search(
|
||||
r'<tool_name>(.*?)</tool_name>',
|
||||
tool_call_content,
|
||||
re.DOTALL | re.IGNORECASE
|
||||
)
|
||||
|
||||
# 提取参数
|
||||
args_match = re.search(
|
||||
r'<arguments>(.*?)</arguments>',
|
||||
tool_call_content,
|
||||
re.DOTALL | re.IGNORECASE
|
||||
)
|
||||
|
||||
if name_match and args_match:
|
||||
tool_name = name_match.group(1).strip()
|
||||
arguments_str = args_match.group(1).strip()
|
||||
|
||||
try:
|
||||
# 解析JSON参数
|
||||
arguments = json.loads(arguments_str)
|
||||
|
||||
# 构建标准格式的工具调用
|
||||
tool_calls.append({
|
||||
"id": f"call_{i}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_name,
|
||||
"arguments": json.dumps(arguments, ensure_ascii=False)
|
||||
}
|
||||
})
|
||||
|
||||
logger.info(f"✅ 解析工具调用: {tool_name}")
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"❌ 解析工具参数失败: {arguments_str}, 错误: {e}")
|
||||
continue
|
||||
|
||||
has_tool_calls = len(tool_calls) > 0
|
||||
|
||||
if has_tool_calls:
|
||||
logger.info(f"✅ 从响应中解析出 {len(tool_calls)} 个工具调用")
|
||||
|
||||
return ToolCallResult(
|
||||
tool_calls=tool_calls,
|
||||
raw_response=ai_response,
|
||||
has_tool_calls=has_tool_calls,
|
||||
needs_continuation=has_tool_calls
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 解析工具调用失败: {e}", exc_info=True)
|
||||
return ToolCallResult(
|
||||
tool_calls=[],
|
||||
raw_response=ai_response,
|
||||
has_tool_calls=False
|
||||
)
|
||||
|
||||
def build_continuation_prompt(
|
||||
self,
|
||||
original_message: str,
|
||||
ai_response: str,
|
||||
tool_results: List[Dict[str, Any]]
|
||||
) -> str:
|
||||
"""构建包含工具结果的继续对话提示词"""
|
||||
|
||||
# 格式化工具结果
|
||||
results_text = self._format_tool_results(tool_results)
|
||||
|
||||
continuation = f"""你之前尝试使用工具来回答用户的问题。
|
||||
|
||||
原始问题:{original_message}
|
||||
|
||||
你的工具调用:
|
||||
{self._extract_tool_calls_text(ai_response)}
|
||||
|
||||
工具执行结果:
|
||||
{results_text}
|
||||
|
||||
现在,请基于这些工具的执行结果,给出完整、详细的回答。不要重复调用工具,直接使用已有的结果来回答用户的问题。"""
|
||||
|
||||
return continuation
|
||||
|
||||
def _format_tool_results(self, tool_results: List[Dict[str, Any]]) -> str:
|
||||
"""格式化工具结果为可读文本"""
|
||||
lines = []
|
||||
|
||||
for i, result in enumerate(tool_results, 1):
|
||||
tool_name = result.get("name", "unknown")
|
||||
success = result.get("success", False)
|
||||
content = result.get("content", "")
|
||||
|
||||
status = "✅ 成功" if success else "❌ 失败"
|
||||
lines.append(f"{i}. {tool_name} - {status}")
|
||||
|
||||
if success:
|
||||
# 尝试美化JSON内容
|
||||
try:
|
||||
if isinstance(content, str):
|
||||
content_obj = json.loads(content)
|
||||
content = json.dumps(content_obj, ensure_ascii=False, indent=2)
|
||||
except:
|
||||
pass
|
||||
lines.append(f"```\n{content}\n```")
|
||||
else:
|
||||
error = result.get("error", "未知错误")
|
||||
lines.append(f"错误信息: {error}")
|
||||
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _extract_tool_calls_text(self, ai_response: str) -> str:
|
||||
"""从AI响应中提取工具调用部分的文本"""
|
||||
match = re.search(
|
||||
r'<tool_calls>(.*?)</tool_calls>',
|
||||
ai_response,
|
||||
re.DOTALL | re.IGNORECASE
|
||||
)
|
||||
|
||||
if match:
|
||||
return match.group(0)
|
||||
return "(未找到工具调用)"
|
||||
@@ -1,353 +0,0 @@
|
||||
"""通用MCP适配器 - 自动检测API能力并选择最佳适配器"""
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass
|
||||
|
||||
from app.mcp.adapters.base import BaseMCPAdapter, AdapterType, ToolCallResult
|
||||
from app.mcp.adapters.prompt_injection import PromptInjectionAdapter
|
||||
from app.mcp.adapters.function_calling import FunctionCallingAdapter
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class APICapability:
|
||||
"""API能力检测结果"""
|
||||
supports_function_calling: bool
|
||||
tested_at: datetime
|
||||
test_duration_ms: float
|
||||
error_message: Optional[str] = None
|
||||
|
||||
|
||||
class UniversalMCPAdapter:
|
||||
"""
|
||||
通用MCP适配器管理器
|
||||
|
||||
功能:
|
||||
1. 自动检测API是否支持Function Calling
|
||||
2. 缓存检测结果
|
||||
3. 自动降级策略:FC失败时切换到提示词注入
|
||||
4. 提供统一接口
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cache_ttl_hours: int = 24,
|
||||
enable_auto_fallback: bool = True
|
||||
):
|
||||
"""
|
||||
初始化通用适配器
|
||||
|
||||
Args:
|
||||
cache_ttl_hours: 能力检测缓存时长(小时)
|
||||
enable_auto_fallback: 是否启用自动降级
|
||||
"""
|
||||
# 适配器实例
|
||||
self.adapters = {
|
||||
AdapterType.FUNCTION_CALLING: FunctionCallingAdapter(),
|
||||
AdapterType.PROMPT_INJECTION: PromptInjectionAdapter()
|
||||
}
|
||||
|
||||
# API能力缓存: {api_identifier: APICapability}
|
||||
self._capability_cache: Dict[str, APICapability] = {}
|
||||
self._cache_ttl = timedelta(hours=cache_ttl_hours)
|
||||
self._cache_lock = asyncio.Lock()
|
||||
|
||||
# 配置
|
||||
self._enable_auto_fallback = enable_auto_fallback
|
||||
|
||||
logger.info(
|
||||
f"✅ UniversalMCPAdapter初始化完成 "
|
||||
f"(缓存TTL={cache_ttl_hours}小时, 自动降级={'开启' if enable_auto_fallback else '关闭'})"
|
||||
)
|
||||
|
||||
async def get_adapter(
|
||||
self,
|
||||
api_identifier: str,
|
||||
test_function: Optional[callable] = None
|
||||
) -> BaseMCPAdapter:
|
||||
"""
|
||||
获取适合当前API的适配器
|
||||
|
||||
Args:
|
||||
api_identifier: API标识符(如"openai_official", "azure_openai"等)
|
||||
test_function: 可选的测试函数,用于检测API能力
|
||||
|
||||
Returns:
|
||||
最适合的适配器实例
|
||||
"""
|
||||
|
||||
# 检查缓存
|
||||
capability = await self._get_cached_capability(api_identifier)
|
||||
|
||||
if capability is None and test_function:
|
||||
# 缓存未命中,执行检测
|
||||
capability = await self._detect_capability(api_identifier, test_function)
|
||||
|
||||
# 选择适配器
|
||||
if capability and capability.supports_function_calling:
|
||||
logger.info(f"🎯 使用Function Calling适配器: {api_identifier}")
|
||||
return self.adapters[AdapterType.FUNCTION_CALLING]
|
||||
else:
|
||||
logger.info(f"🎯 使用提示词注入适配器: {api_identifier}")
|
||||
return self.adapters[AdapterType.PROMPT_INJECTION]
|
||||
|
||||
async def _get_cached_capability(
|
||||
self,
|
||||
api_identifier: str
|
||||
) -> Optional[APICapability]:
|
||||
"""获取缓存的能力检测结果"""
|
||||
|
||||
async with self._cache_lock:
|
||||
if api_identifier not in self._capability_cache:
|
||||
return None
|
||||
|
||||
capability = self._capability_cache[api_identifier]
|
||||
|
||||
# 检查是否过期
|
||||
if datetime.now() - capability.tested_at > self._cache_ttl:
|
||||
logger.info(f"⏰ API能力缓存过期: {api_identifier}")
|
||||
del self._capability_cache[api_identifier]
|
||||
return None
|
||||
|
||||
logger.debug(f"🎯 API能力缓存命中: {api_identifier}")
|
||||
return capability
|
||||
|
||||
async def _detect_capability(
|
||||
self,
|
||||
api_identifier: str,
|
||||
test_function: callable
|
||||
) -> APICapability:
|
||||
"""
|
||||
检测API能力
|
||||
|
||||
Args:
|
||||
api_identifier: API标识符
|
||||
test_function: 测试函数,应该尝试使用Function Calling
|
||||
|
||||
Returns:
|
||||
能力检测结果
|
||||
"""
|
||||
|
||||
logger.info(f"🔍 开始检测API能力: {api_identifier}")
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 调用测试函数
|
||||
result = await test_function()
|
||||
|
||||
# 判断是否成功
|
||||
supports_fc = self._is_function_calling_response(result)
|
||||
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
|
||||
capability = APICapability(
|
||||
supports_function_calling=supports_fc,
|
||||
tested_at=datetime.now(),
|
||||
test_duration_ms=duration_ms
|
||||
)
|
||||
|
||||
# 缓存结果
|
||||
async with self._cache_lock:
|
||||
self._capability_cache[api_identifier] = capability
|
||||
|
||||
status = "✅ 支持" if supports_fc else "❌ 不支持"
|
||||
logger.info(
|
||||
f"{status} Function Calling: {api_identifier} "
|
||||
f"(耗时: {duration_ms:.2f}ms)"
|
||||
)
|
||||
|
||||
return capability
|
||||
|
||||
except Exception as e:
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
|
||||
logger.warning(
|
||||
f"⚠️ API能力检测失败: {api_identifier}, 错误: {e}, "
|
||||
f"将使用提示词注入模式"
|
||||
)
|
||||
|
||||
capability = APICapability(
|
||||
supports_function_calling=False,
|
||||
tested_at=datetime.now(),
|
||||
test_duration_ms=duration_ms,
|
||||
error_message=str(e)
|
||||
)
|
||||
|
||||
# 缓存失败结果(避免重复测试)
|
||||
async with self._cache_lock:
|
||||
self._capability_cache[api_identifier] = capability
|
||||
|
||||
return capability
|
||||
|
||||
def _is_function_calling_response(self, response: Any) -> bool:
|
||||
"""
|
||||
判断响应是否是Function Calling格式
|
||||
|
||||
Args:
|
||||
response: API响应
|
||||
|
||||
Returns:
|
||||
是否支持Function Calling
|
||||
"""
|
||||
|
||||
try:
|
||||
# 检查字典格式
|
||||
if isinstance(response, dict):
|
||||
message = response.get("choices", [{}])[0].get("message", {})
|
||||
return "tool_calls" in message or "function_call" in message
|
||||
|
||||
# 检查对象格式(OpenAI SDK)
|
||||
if hasattr(response, "choices"):
|
||||
message = response.choices[0].message
|
||||
return hasattr(message, "tool_calls") or hasattr(message, "function_call")
|
||||
|
||||
return False
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def call_with_fallback(
|
||||
self,
|
||||
api_identifier: str,
|
||||
tools: List[Dict[str, Any]],
|
||||
user_message: str,
|
||||
call_function: callable,
|
||||
test_function: Optional[callable] = None
|
||||
) -> ToolCallResult:
|
||||
"""
|
||||
带降级策略的工具调用
|
||||
|
||||
Args:
|
||||
api_identifier: API标识符
|
||||
tools: MCP工具列表
|
||||
user_message: 用户消息
|
||||
call_function: 实际调用API的函数
|
||||
test_function: 可选的测试函数
|
||||
|
||||
Returns:
|
||||
工具调用结果
|
||||
"""
|
||||
|
||||
# 获取适配器
|
||||
adapter = await self.get_adapter(api_identifier, test_function)
|
||||
|
||||
# 首次尝试
|
||||
try:
|
||||
if adapter.supports_native_tools():
|
||||
# Function Calling模式
|
||||
logger.info("🚀 尝试使用Function Calling模式")
|
||||
result = await self._try_function_calling(
|
||||
tools, user_message, call_function, adapter
|
||||
)
|
||||
else:
|
||||
# 提示词注入模式
|
||||
logger.info("🚀 使用提示词注入模式")
|
||||
result = await self._try_prompt_injection(
|
||||
tools, user_message, call_function, adapter
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 工具调用失败: {e}")
|
||||
|
||||
# 自动降级
|
||||
if self._enable_auto_fallback and adapter.supports_native_tools():
|
||||
logger.warning("⚠️ Function Calling失败,降级到提示词注入模式")
|
||||
|
||||
# 更新缓存,标记为不支持
|
||||
async with self._cache_lock:
|
||||
self._capability_cache[api_identifier] = APICapability(
|
||||
supports_function_calling=False,
|
||||
tested_at=datetime.now(),
|
||||
test_duration_ms=0,
|
||||
error_message=str(e)
|
||||
)
|
||||
|
||||
# 使用提示词注入重试
|
||||
fallback_adapter = self.adapters[AdapterType.PROMPT_INJECTION]
|
||||
return await self._try_prompt_injection(
|
||||
tools, user_message, call_function, fallback_adapter
|
||||
)
|
||||
|
||||
raise
|
||||
|
||||
async def _try_function_calling(
|
||||
self,
|
||||
tools: List[Dict[str, Any]],
|
||||
user_message: str,
|
||||
call_function: callable,
|
||||
adapter: FunctionCallingAdapter
|
||||
) -> ToolCallResult:
|
||||
"""尝试Function Calling模式"""
|
||||
|
||||
# Function Calling不需要修改提示词
|
||||
response = await call_function(
|
||||
message=user_message,
|
||||
tools_param=tools,
|
||||
tool_choice_param="auto"
|
||||
)
|
||||
|
||||
return adapter.parse_tool_calls(response)
|
||||
|
||||
async def _try_prompt_injection(
|
||||
self,
|
||||
tools: List[Dict[str, Any]],
|
||||
user_message: str,
|
||||
call_function: callable,
|
||||
adapter: PromptInjectionAdapter
|
||||
) -> ToolCallResult:
|
||||
"""尝试提示词注入模式"""
|
||||
|
||||
# 注入工具到提示词
|
||||
enhanced_prompt = adapter.format_tools_for_prompt(tools, user_message)
|
||||
|
||||
# 调用API(不传tools参数)
|
||||
response = await call_function(
|
||||
message=enhanced_prompt,
|
||||
tools_param=None,
|
||||
tool_choice_param=None
|
||||
)
|
||||
|
||||
# 从文本响应中解析工具调用
|
||||
return adapter.parse_tool_calls(response)
|
||||
|
||||
def clear_cache(self, api_identifier: Optional[str] = None):
|
||||
"""
|
||||
清理能力缓存
|
||||
|
||||
Args:
|
||||
api_identifier: 可选,只清理特定API的缓存
|
||||
"""
|
||||
if api_identifier:
|
||||
if api_identifier in self._capability_cache:
|
||||
del self._capability_cache[api_identifier]
|
||||
logger.info(f"🧹 已清理API能力缓存: {api_identifier}")
|
||||
else:
|
||||
self._capability_cache.clear()
|
||||
logger.info("🧹 已清理所有API能力缓存")
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
"""获取缓存统计信息"""
|
||||
return {
|
||||
"total_cached": len(self._capability_cache),
|
||||
"cache_ttl_hours": self._cache_ttl.total_seconds() / 3600,
|
||||
"cached_apis": [
|
||||
{
|
||||
"api_identifier": api_id,
|
||||
"supports_fc": cap.supports_function_calling,
|
||||
"tested_at": cap.tested_at.isoformat(),
|
||||
"test_duration_ms": cap.test_duration_ms
|
||||
}
|
||||
for api_id, cap in self._capability_cache.items()
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
# 全局单例
|
||||
universal_mcp_adapter = UniversalMCPAdapter()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,385 +0,0 @@
|
||||
"""HTTP MCP客户端 - 使用官方 MCP Python SDK 实现"""
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Optional
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from mcp import ClientSession, types
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from pydantic import AnyUrl
|
||||
from anyio import ClosedResourceError
|
||||
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MCPError(Exception):
|
||||
"""MCP错误"""
|
||||
pass
|
||||
|
||||
|
||||
class HTTPMCPClient:
|
||||
"""HTTP模式MCP客户端(基于官方 MCP Python SDK)"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
env: Optional[Dict[str, str]] = None,
|
||||
timeout: float = 60.0
|
||||
):
|
||||
"""
|
||||
初始化HTTP MCP客户端
|
||||
|
||||
Args:
|
||||
url: MCP服务器URL
|
||||
headers: HTTP请求头
|
||||
env: 环境变量(用于API Key等)
|
||||
timeout: 超时时间(秒)
|
||||
"""
|
||||
self.url = url.rstrip('/')
|
||||
self.headers = headers or {}
|
||||
self.env = env or {}
|
||||
self.timeout = timeout
|
||||
|
||||
# 如果env中有API Key,添加到headers
|
||||
if 'API_KEY' in self.env:
|
||||
self.headers['Authorization'] = f'Bearer {self.env["API_KEY"]}'
|
||||
|
||||
self._session: Optional[ClientSession] = None
|
||||
self._context_stack = [] # 保存上下文管理器栈
|
||||
self._initialized = False
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def _ensure_connected(self):
|
||||
"""确保连接已建立"""
|
||||
async with self._lock:
|
||||
if self._session is None:
|
||||
try:
|
||||
logger.info(f"🔗 连接到MCP服务器: {self.url}")
|
||||
|
||||
# 使用官方 SDK 的 streamable_http_client
|
||||
# 保存上下文管理器以便后续正确清理
|
||||
stream_context = streamablehttp_client(self.url)
|
||||
read_stream, write_stream, _ = await stream_context.__aenter__()
|
||||
self._context_stack.append(('stream', stream_context))
|
||||
|
||||
# 创建客户端会话
|
||||
self._session = ClientSession(read_stream, write_stream)
|
||||
session_context = self._session
|
||||
await session_context.__aenter__()
|
||||
self._context_stack.append(('session', session_context))
|
||||
|
||||
# 初始化会话
|
||||
await self._session.initialize()
|
||||
self._initialized = True
|
||||
|
||||
logger.info(f"✅ MCP会话初始化成功")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ MCP连接失败: {e}")
|
||||
await self._cleanup()
|
||||
raise MCPError(f"连接MCP服务器失败: {str(e)}")
|
||||
|
||||
async def _cleanup(self):
|
||||
"""清理连接资源(按照进入的相反顺序退出)"""
|
||||
# 按照LIFO顺序清理上下文
|
||||
while self._context_stack:
|
||||
ctx_type, ctx = self._context_stack.pop()
|
||||
try:
|
||||
await ctx.__aexit__(None, None, None)
|
||||
except RuntimeError as e:
|
||||
# 忽略 anyio 的任务上下文错误(在关闭时可能发生)
|
||||
if "cancel scope" in str(e).lower() or "different task" in str(e).lower():
|
||||
logger.debug(f"忽略{ctx_type}上下文清理的任务切换警告: {e}")
|
||||
else:
|
||||
logger.error(f"清理{ctx_type}上下文失败: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"清理{ctx_type}上下文失败: {e}")
|
||||
|
||||
self._session = None
|
||||
self._initialized = False
|
||||
|
||||
async def initialize(self) -> Dict[str, Any]:
|
||||
"""
|
||||
初始化MCP会话
|
||||
|
||||
Returns:
|
||||
初始化响应
|
||||
"""
|
||||
await self._ensure_connected()
|
||||
return {"status": "initialized"}
|
||||
|
||||
async def list_tools(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
列举可用工具
|
||||
|
||||
Returns:
|
||||
工具列表
|
||||
"""
|
||||
try:
|
||||
await self._ensure_connected()
|
||||
|
||||
result = await self._session.list_tools()
|
||||
|
||||
# 转换为字典格式
|
||||
tools = []
|
||||
for tool in result.tools:
|
||||
tool_dict = {
|
||||
"name": tool.name,
|
||||
"description": tool.description or "",
|
||||
"inputSchema": tool.inputSchema
|
||||
}
|
||||
tools.append(tool_dict)
|
||||
|
||||
logger.info(f"获取到 {len(tools)} 个工具")
|
||||
return tools
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具列表失败: {e}")
|
||||
raise MCPError(f"获取工具列表失败: {str(e)}")
|
||||
|
||||
async def call_tool(
|
||||
self,
|
||||
tool_name: str,
|
||||
arguments: Dict[str, Any],
|
||||
max_reconnect_attempts: int = 2
|
||||
) -> Any:
|
||||
"""
|
||||
调用工具(带自动重连)
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
arguments: 工具参数
|
||||
max_reconnect_attempts: 最大重连尝试次数
|
||||
|
||||
Returns:
|
||||
工具执行结果
|
||||
"""
|
||||
for attempt in range(max_reconnect_attempts + 1):
|
||||
try:
|
||||
await self._ensure_connected()
|
||||
|
||||
logger.info(f"调用工具: {tool_name}")
|
||||
logger.debug(f" 参数类型: {type(arguments)}")
|
||||
logger.debug(f" 参数内容: {arguments}")
|
||||
logger.debug(f" 会话状态: initialized={self._initialized}, session={self._session is not None}")
|
||||
|
||||
result = await self._session.call_tool(tool_name, arguments)
|
||||
|
||||
logger.debug(f" 工具返回类型: {type(result)}")
|
||||
logger.debug(f" 返回内容: {result}")
|
||||
|
||||
# 处理返回结果
|
||||
# MCP SDK 返回 CallToolResult 对象
|
||||
if result.content:
|
||||
logger.debug(f" 返回content数量: {len(result.content)}")
|
||||
# 提取第一个content的文本
|
||||
for idx, content in enumerate(result.content):
|
||||
logger.debug(f" content[{idx}]类型: {type(content)}")
|
||||
if isinstance(content, types.TextContent):
|
||||
logger.debug(f" ✅ 返回TextContent: {content.text[:100] if len(content.text) > 100 else content.text}")
|
||||
return content.text
|
||||
elif isinstance(content, types.ImageContent):
|
||||
logger.debug(f" ✅ 返回ImageContent")
|
||||
return {
|
||||
"type": "image",
|
||||
"data": content.data,
|
||||
"mimeType": content.mimeType
|
||||
}
|
||||
# 如果没有文本内容,返回原始内容
|
||||
logger.debug(f" ⚠️ 返回原始content[0]")
|
||||
return result.content[0] if result.content else None
|
||||
|
||||
# 如果有结构化内容(2025-06-18规范)
|
||||
if hasattr(result, 'structuredContent') and result.structuredContent:
|
||||
logger.debug(f" ✅ 返回structuredContent")
|
||||
return result.structuredContent
|
||||
|
||||
logger.warning(f" ⚠️ 工具返回为None")
|
||||
return None
|
||||
|
||||
except ClosedResourceError as e:
|
||||
# 连接已关闭,尝试重连
|
||||
if attempt < max_reconnect_attempts:
|
||||
logger.warning(
|
||||
f"⚠️ MCP连接已关闭,尝试重新连接 "
|
||||
f"(第{attempt + 1}/{max_reconnect_attempts}次重连)"
|
||||
)
|
||||
await self._cleanup()
|
||||
await asyncio.sleep(0.5) # 短暂延迟后重连
|
||||
continue
|
||||
else:
|
||||
logger.error(f"❌ MCP连接重连失败,已达最大重试次数")
|
||||
error_msg = f"连接已关闭且重连失败 (尝试了{max_reconnect_attempts}次)"
|
||||
raise MCPError(error_msg)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"调用工具失败: {tool_name}, 错误: {e}", exc_info=True)
|
||||
logger.error(f" 参数: {arguments}")
|
||||
logger.error(f" 错误类型: {type(e).__name__}")
|
||||
logger.error(f" 错误详情: {repr(e)}")
|
||||
logger.error(f" 错误字符串: '{str(e)}'")
|
||||
error_msg = str(e) or repr(e) or f"未知错误 ({type(e).__name__})"
|
||||
raise MCPError(f"调用工具失败: {error_msg}")
|
||||
|
||||
# 理论上不会到这里
|
||||
raise MCPError(f"工具调用失败: 未知错误")
|
||||
|
||||
async def list_resources(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
列举可用资源
|
||||
|
||||
Returns:
|
||||
资源列表
|
||||
"""
|
||||
try:
|
||||
await self._ensure_connected()
|
||||
|
||||
result = await self._session.list_resources()
|
||||
|
||||
# 转换为字典格式
|
||||
resources = []
|
||||
for resource in result.resources:
|
||||
resource_dict = {
|
||||
"uri": str(resource.uri),
|
||||
"name": resource.name,
|
||||
"description": resource.description or "",
|
||||
"mimeType": resource.mimeType or ""
|
||||
}
|
||||
resources.append(resource_dict)
|
||||
|
||||
logger.info(f"获取到 {len(resources)} 个资源")
|
||||
return resources
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取资源列表失败: {e}")
|
||||
raise MCPError(f"获取资源列表失败: {str(e)}")
|
||||
|
||||
async def read_resource(self, uri: str) -> Any:
|
||||
"""
|
||||
读取资源
|
||||
|
||||
Args:
|
||||
uri: 资源URI
|
||||
|
||||
Returns:
|
||||
资源内容
|
||||
"""
|
||||
try:
|
||||
await self._ensure_connected()
|
||||
|
||||
result = await self._session.read_resource(AnyUrl(uri))
|
||||
|
||||
# 提取资源内容
|
||||
if result.contents:
|
||||
content = result.contents[0]
|
||||
if isinstance(content, types.TextContent):
|
||||
return content.text
|
||||
elif isinstance(content, types.ImageContent):
|
||||
return {
|
||||
"type": "image",
|
||||
"data": content.data,
|
||||
"mimeType": content.mimeType
|
||||
}
|
||||
elif isinstance(content, types.BlobResourceContents):
|
||||
return {
|
||||
"type": "blob",
|
||||
"blob": content.blob,
|
||||
"mimeType": content.mimeType
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"读取资源失败: {uri}, 错误: {e}")
|
||||
raise MCPError(f"读取资源失败: {str(e)}")
|
||||
|
||||
async def test_connection(self) -> Dict[str, Any]:
|
||||
"""
|
||||
测试连接
|
||||
|
||||
Returns:
|
||||
测试结果
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 尝试连接并列举工具(直接调用SDK,避免重复日志)
|
||||
await self._ensure_connected()
|
||||
|
||||
result = await self._session.list_tools()
|
||||
|
||||
# 转换为字典格式
|
||||
tools = []
|
||||
for tool in result.tools:
|
||||
tool_dict = {
|
||||
"name": tool.name,
|
||||
"description": tool.description or "",
|
||||
"inputSchema": tool.inputSchema
|
||||
}
|
||||
tools.append(tool_dict)
|
||||
|
||||
end_time = time.time()
|
||||
response_time = round((end_time - start_time) * 1000, 2)
|
||||
|
||||
logger.info(f"✅ 连接测试成功,获取到 {len(tools)} 个工具")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "连接测试成功",
|
||||
"response_time_ms": response_time,
|
||||
"tools_count": len(tools),
|
||||
"tools": tools
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
response_time = round((end_time - start_time) * 1000, 2)
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"message": "连接测试失败",
|
||||
"response_time_ms": response_time,
|
||||
"error": str(e),
|
||||
"error_type": type(e).__name__,
|
||||
"suggestions": [
|
||||
"请检查服务器URL是否正确",
|
||||
"请确认API Key是否有效",
|
||||
"请检查网络连接",
|
||||
"请确认MCP服务器是否在线"
|
||||
]
|
||||
}
|
||||
|
||||
async def close(self):
|
||||
"""关闭客户端连接"""
|
||||
logger.info(f"关闭MCP客户端: {self.url}")
|
||||
await self._cleanup()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def create_mcp_client(
|
||||
url: str,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
env: Optional[Dict[str, str]] = None,
|
||||
timeout: float = 60.0
|
||||
):
|
||||
"""
|
||||
创建MCP客户端的上下文管理器
|
||||
|
||||
Args:
|
||||
url: MCP服务器URL
|
||||
headers: HTTP请求头
|
||||
env: 环境变量
|
||||
timeout: 超时时间
|
||||
|
||||
Yields:
|
||||
HTTPMCPClient实例
|
||||
"""
|
||||
client = HTTPMCPClient(url, headers, env, timeout)
|
||||
try:
|
||||
await client.initialize()
|
||||
yield client
|
||||
finally:
|
||||
await client.close()
|
||||
@@ -1,527 +0,0 @@
|
||||
"""MCP插件注册表 - 管理运行时插件实例"""
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, Optional, Any, List
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from app.mcp.http_client import HTTPMCPClient, MCPError
|
||||
from app.mcp.config import mcp_config
|
||||
from app.models.mcp_plugin import MCPPlugin
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionInfo:
|
||||
"""会话信息"""
|
||||
client: HTTPMCPClient
|
||||
created_at: float
|
||||
last_access: float
|
||||
request_count: int = 0
|
||||
error_count: int = 0
|
||||
status: str = "active" # active, degraded, error
|
||||
|
||||
|
||||
class MCPPluginRegistry:
|
||||
"""MCP插件注册表 - 管理运行时插件实例(优化版)"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_clients: Optional[int] = None,
|
||||
client_ttl: Optional[int] = None
|
||||
):
|
||||
"""
|
||||
初始化注册表
|
||||
|
||||
Args:
|
||||
max_clients: 最大缓存客户端数量(默认使用配置)
|
||||
client_ttl: 客户端过期时间(秒,默认使用配置)
|
||||
"""
|
||||
# 存储格式: {plugin_id: SessionInfo}
|
||||
self._sessions: Dict[str, SessionInfo] = {}
|
||||
|
||||
# 全局锁用于保护会话字典
|
||||
self._sessions_lock = asyncio.Lock()
|
||||
|
||||
# 细粒度锁:每个用户一个锁
|
||||
self._user_locks: Dict[str, asyncio.Lock] = {}
|
||||
self._locks_lock = asyncio.Lock() # 保护locks字典本身
|
||||
|
||||
# 配置参数(使用配置常量)
|
||||
self._max_clients = max_clients or mcp_config.MAX_CLIENTS
|
||||
self._client_ttl = client_ttl or mcp_config.CLIENT_TTL_SECONDS
|
||||
|
||||
# 启动后台清理任务
|
||||
self._cleanup_task = None
|
||||
self._health_check_task = None
|
||||
self._tasks_started = False
|
||||
|
||||
def _ensure_background_tasks(self):
|
||||
"""确保后台任务已启动(延迟初始化)"""
|
||||
if not self._tasks_started:
|
||||
try:
|
||||
# 检查是否有运行中的事件循环
|
||||
loop = asyncio.get_running_loop()
|
||||
if self._cleanup_task is None:
|
||||
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||
logger.info("✅ MCP插件注册表后台清理任务已启动")
|
||||
|
||||
if self._health_check_task is None:
|
||||
self._health_check_task = asyncio.create_task(self._health_check_loop())
|
||||
logger.info("✅ MCP会话健康检查任务已启动")
|
||||
|
||||
self._tasks_started = True
|
||||
except RuntimeError:
|
||||
# 没有运行中的事件循环,稍后再试
|
||||
pass
|
||||
|
||||
async def _cleanup_loop(self):
|
||||
"""后台清理过期客户端"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(mcp_config.CLEANUP_INTERVAL_SECONDS)
|
||||
await self._cleanup_expired_sessions()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"清理任务异常: {e}")
|
||||
|
||||
async def _health_check_loop(self):
|
||||
"""后台健康检查"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(mcp_config.HEALTH_CHECK_INTERVAL_SECONDS)
|
||||
await self._check_session_health()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"健康检查任务异常: {e}")
|
||||
|
||||
async def _cleanup_expired_sessions(self):
|
||||
"""清理过期的会话"""
|
||||
now = time.time()
|
||||
expired_ids = []
|
||||
|
||||
async with self._sessions_lock:
|
||||
# 收集过期的plugin_id
|
||||
for plugin_id, session in list(self._sessions.items()):
|
||||
if now - session.last_access > self._client_ttl:
|
||||
expired_ids.append(plugin_id)
|
||||
|
||||
if expired_ids:
|
||||
logger.info(f"🧹 清理 {len(expired_ids)} 个过期的MCP会话")
|
||||
for plugin_id in expired_ids:
|
||||
# 提取user_id来获取对应的锁
|
||||
user_id = plugin_id.split(':', 1)[0]
|
||||
user_lock = await self._get_user_lock(user_id)
|
||||
|
||||
async with user_lock:
|
||||
async with self._sessions_lock:
|
||||
if plugin_id in self._sessions:
|
||||
await self._unload_plugin_unsafe(plugin_id)
|
||||
|
||||
async def _check_session_health(self):
|
||||
"""增强的会话健康检查"""
|
||||
async with self._sessions_lock:
|
||||
for plugin_id, session in list(self._sessions.items()):
|
||||
# 计算错误率
|
||||
if session.request_count > mcp_config.MIN_REQUESTS_FOR_HEALTH_CHECK:
|
||||
error_rate = session.error_count / session.request_count
|
||||
|
||||
# 动态调整状态(使用配置常量)
|
||||
if error_rate > mcp_config.ERROR_RATE_CRITICAL:
|
||||
if session.status != "error":
|
||||
session.status = "error"
|
||||
logger.error(
|
||||
f"❌ 会话 {plugin_id} 错误率过高 "
|
||||
f"({error_rate:.1%}), 标记为error"
|
||||
)
|
||||
elif error_rate > mcp_config.ERROR_RATE_WARNING:
|
||||
if session.status == "active":
|
||||
session.status = "degraded"
|
||||
logger.warning(
|
||||
f"⚠️ 会话 {plugin_id} 健康状况下降 "
|
||||
f"(错误率: {error_rate:.1%})"
|
||||
)
|
||||
elif session.status == "degraded":
|
||||
# 错误率降低,恢复正常
|
||||
session.status = "active"
|
||||
logger.info(f"✅ 会话 {plugin_id} 恢复正常")
|
||||
|
||||
# 检查即将过期的会话(最后1分钟提醒)
|
||||
idle_time = time.time() - session.last_access
|
||||
time_until_expiry = self._client_ttl - idle_time
|
||||
|
||||
# 仅在最后1分钟(60秒)内提醒一次
|
||||
if 0 < time_until_expiry <= 60:
|
||||
# 使用会话属性避免重复提醒
|
||||
if not hasattr(session, '_expiry_warned') or not session._expiry_warned:
|
||||
logger.warning(
|
||||
f"⏰ 会话 {plugin_id} 即将过期 "
|
||||
f"(剩余 {time_until_expiry:.0f} 秒)"
|
||||
)
|
||||
session._expiry_warned = True
|
||||
elif time_until_expiry > 60:
|
||||
# 重置警告标志(如果会话被重新使用)
|
||||
if hasattr(session, '_expiry_warned'):
|
||||
session._expiry_warned = False
|
||||
|
||||
async def _get_user_lock(self, user_id: str) -> asyncio.Lock:
|
||||
"""
|
||||
获取用户专属的锁(细粒度锁)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
该用户的锁对象
|
||||
"""
|
||||
async with self._locks_lock:
|
||||
if user_id not in self._user_locks:
|
||||
self._user_locks[user_id] = asyncio.Lock()
|
||||
return self._user_locks[user_id]
|
||||
|
||||
def _touch_session(self, plugin_id: str):
|
||||
"""
|
||||
更新会话的最后访问时间(需要在锁内调用)
|
||||
|
||||
Args:
|
||||
plugin_id: 插件ID
|
||||
"""
|
||||
if plugin_id in self._sessions:
|
||||
session = self._sessions[plugin_id]
|
||||
session.last_access = time.time()
|
||||
session.request_count += 1
|
||||
|
||||
async def _evict_lru_session(self):
|
||||
"""驱逐最久未使用的会话(当达到max_clients限制时)"""
|
||||
if len(self._sessions) >= self._max_clients:
|
||||
# 找到最旧的会话
|
||||
oldest_id = None
|
||||
oldest_time = float('inf')
|
||||
|
||||
for plugin_id, session in self._sessions.items():
|
||||
if session.last_access < oldest_time:
|
||||
oldest_time = session.last_access
|
||||
oldest_id = plugin_id
|
||||
|
||||
if oldest_id:
|
||||
logger.info(f"📤 达到最大会话数量限制,驱逐: {oldest_id}")
|
||||
await self._unload_plugin_unsafe(oldest_id)
|
||||
|
||||
async def load_plugin(self, plugin: MCPPlugin) -> bool:
|
||||
"""
|
||||
从配置加载插件
|
||||
|
||||
Args:
|
||||
plugin: 插件配置
|
||||
|
||||
Returns:
|
||||
是否加载成功
|
||||
"""
|
||||
# 确保后台任务已启动
|
||||
self._ensure_background_tasks()
|
||||
|
||||
# 使用细粒度锁(只锁定当前用户)
|
||||
user_lock = await self._get_user_lock(plugin.user_id)
|
||||
async with user_lock:
|
||||
try:
|
||||
plugin_id = f"{plugin.user_id}:{plugin.plugin_name}"
|
||||
|
||||
# 如果已加载,先卸载
|
||||
async with self._sessions_lock:
|
||||
if plugin_id in self._sessions:
|
||||
await self._unload_plugin_unsafe(plugin_id)
|
||||
|
||||
# 检查是否需要驱逐LRU会话
|
||||
await self._evict_lru_session()
|
||||
|
||||
# 目前只支持HTTP类型
|
||||
if plugin.plugin_type == "http":
|
||||
if not plugin.server_url:
|
||||
logger.error(f"HTTP插件缺少server_url: {plugin.plugin_name}")
|
||||
return False
|
||||
|
||||
# 为每个插件创建独立的HTTP客户端
|
||||
client = HTTPMCPClient(
|
||||
url=plugin.server_url,
|
||||
headers=plugin.headers or {},
|
||||
env=plugin.env or {},
|
||||
timeout=plugin.config.get('timeout', 60.0) if plugin.config else 60.0
|
||||
)
|
||||
|
||||
# 创建会话信息
|
||||
now = time.time()
|
||||
session = SessionInfo(
|
||||
client=client,
|
||||
created_at=now,
|
||||
last_access=now,
|
||||
request_count=0,
|
||||
error_count=0,
|
||||
status="active"
|
||||
)
|
||||
|
||||
# 存储会话
|
||||
async with self._sessions_lock:
|
||||
self._sessions[plugin_id] = session
|
||||
|
||||
logger.info(f"✅ 加载MCP插件: {plugin_id} (独立会话)")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"暂不支持的插件类型: {plugin.plugin_type}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载插件失败 {plugin.plugin_name}: {e}")
|
||||
return False
|
||||
|
||||
async def unload_plugin(self, user_id: str, plugin_name: str):
|
||||
"""
|
||||
卸载插件
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
plugin_name: 插件名称
|
||||
"""
|
||||
# 使用细粒度锁(只锁定当前用户)
|
||||
user_lock = await self._get_user_lock(user_id)
|
||||
async with user_lock:
|
||||
plugin_id = f"{user_id}:{plugin_name}"
|
||||
async with self._sessions_lock:
|
||||
await self._unload_plugin_unsafe(plugin_id)
|
||||
|
||||
async def _unload_plugin_unsafe(self, plugin_id: str):
|
||||
"""卸载插件(不加锁,内部使用,需要在sessions_lock内调用)"""
|
||||
if plugin_id in self._sessions:
|
||||
session = self._sessions[plugin_id]
|
||||
try:
|
||||
await session.client.close()
|
||||
except Exception as e:
|
||||
logger.error(f"关闭插件客户端失败 {plugin_id}: {e}")
|
||||
|
||||
del self._sessions[plugin_id]
|
||||
logger.info(f"卸载MCP插件: {plugin_id}")
|
||||
|
||||
async def reload_plugin(self, plugin: MCPPlugin) -> bool:
|
||||
"""
|
||||
重新加载插件
|
||||
|
||||
Args:
|
||||
plugin: 插件配置
|
||||
|
||||
Returns:
|
||||
是否重载成功
|
||||
"""
|
||||
await self.unload_plugin(plugin.user_id, plugin.plugin_name)
|
||||
return await self.load_plugin(plugin)
|
||||
|
||||
def get_client(self, user_id: str, plugin_name: str) -> Optional[HTTPMCPClient]:
|
||||
"""
|
||||
获取插件客户端(线程安全,支持访问时间更新)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
客户端实例或None
|
||||
"""
|
||||
plugin_id = f"{user_id}:{plugin_name}"
|
||||
|
||||
session = self._sessions.get(plugin_id)
|
||||
if session:
|
||||
# 检查会话状态
|
||||
if session.status == "error":
|
||||
logger.warning(
|
||||
f"⚠️ 会话 {plugin_id} 处于错误状态,"
|
||||
f"建议调用者重新加载插件"
|
||||
)
|
||||
# 不返回错误状态的客户端
|
||||
return None
|
||||
|
||||
# ✅ 使用锁保护状态更新,避免并发问题
|
||||
# 注意:这里使用原子操作更新简单字段,不需要异步锁
|
||||
session.last_access = time.time()
|
||||
session.request_count += 1
|
||||
return session.client
|
||||
return None
|
||||
|
||||
async def get_or_reconnect_client(
|
||||
self,
|
||||
user_id: str,
|
||||
plugin_name: str,
|
||||
plugin: MCPPlugin
|
||||
) -> HTTPMCPClient:
|
||||
"""
|
||||
获取或重连客户端(自动处理错误状态)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
plugin_name: 插件名称
|
||||
plugin: 插件配置对象
|
||||
|
||||
Returns:
|
||||
客户端实例
|
||||
|
||||
Raises:
|
||||
ValueError: 插件加载失败
|
||||
"""
|
||||
plugin_id = f"{user_id}:{plugin_name}"
|
||||
|
||||
# 获取用户锁
|
||||
user_lock = await self._get_user_lock(user_id)
|
||||
async with user_lock:
|
||||
session = self._sessions.get(plugin_id)
|
||||
|
||||
# 检查会话健康状态
|
||||
if session and session.status == "error":
|
||||
logger.warning(f"会话 {plugin_id} 处于错误状态,尝试重连")
|
||||
async with self._sessions_lock:
|
||||
await self._unload_plugin_unsafe(plugin_id)
|
||||
session = None
|
||||
|
||||
# 如果没有会话,加载插件
|
||||
if not session:
|
||||
success = await self.load_plugin(plugin)
|
||||
if not success:
|
||||
raise ValueError(f"插件加载失败: {plugin_name}")
|
||||
session = self._sessions[plugin_id]
|
||||
|
||||
return session.client
|
||||
|
||||
async def call_tool(
|
||||
self,
|
||||
user_id: str,
|
||||
plugin_name: str,
|
||||
tool_name: str,
|
||||
arguments: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
调用插件工具(带错误计数和状态管理)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
plugin_name: 插件名称
|
||||
tool_name: 工具名称
|
||||
arguments: 工具参数
|
||||
|
||||
Returns:
|
||||
工具执行结果
|
||||
|
||||
Raises:
|
||||
ValueError: 插件不存在或未启用
|
||||
MCPError: 工具调用失败
|
||||
"""
|
||||
plugin_id = f"{user_id}:{plugin_name}"
|
||||
|
||||
# 获取会话
|
||||
session = self._sessions.get(plugin_id)
|
||||
if not session:
|
||||
raise ValueError(f"插件未加载: {plugin_name}")
|
||||
|
||||
try:
|
||||
result = await session.client.call_tool(tool_name, arguments)
|
||||
logger.info(f"✅ 工具调用成功: {plugin_name}.{tool_name}")
|
||||
|
||||
# 调用成功,重置状态(如果之前是degraded)
|
||||
if session.status == "degraded":
|
||||
session.status = "active"
|
||||
logger.info(f"✅ 会话 {plugin_id} 恢复正常")
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
# 增加错误计数
|
||||
session.error_count += 1
|
||||
|
||||
# 根据错误率更新状态
|
||||
if session.request_count > 0:
|
||||
error_rate = session.error_count / session.request_count
|
||||
if error_rate > 0.5:
|
||||
session.status = "error"
|
||||
elif error_rate > 0.3:
|
||||
session.status = "degraded"
|
||||
|
||||
logger.error(
|
||||
f"❌ 工具调用失败: {plugin_name}.{tool_name}, "
|
||||
f"错误: {e} (错误计数: {session.error_count}/{session.request_count})"
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_plugin_tools(
|
||||
self,
|
||||
user_id: str,
|
||||
plugin_name: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取插件的工具列表
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
工具列表
|
||||
"""
|
||||
client = self.get_client(user_id, plugin_name)
|
||||
|
||||
if not client:
|
||||
raise ValueError(f"插件未加载: {plugin_name}")
|
||||
|
||||
try:
|
||||
tools = await client.list_tools()
|
||||
return tools
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具列表失败: {plugin_name}, 错误: {e}")
|
||||
raise
|
||||
|
||||
async def test_plugin(
|
||||
self,
|
||||
user_id: str,
|
||||
plugin_name: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
测试插件连接
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
测试结果
|
||||
"""
|
||||
client = self.get_client(user_id, plugin_name)
|
||||
|
||||
if not client:
|
||||
raise ValueError(f"插件未加载: {plugin_name}")
|
||||
|
||||
return await client.test_connection()
|
||||
|
||||
async def cleanup_all(self):
|
||||
"""清理所有插件和资源"""
|
||||
# 停止后台任务
|
||||
if self._cleanup_task:
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self._health_check_task:
|
||||
self._health_check_task.cancel()
|
||||
try:
|
||||
await self._health_check_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# 清理所有会话
|
||||
async with self._sessions_lock:
|
||||
plugin_ids = list(self._sessions.keys())
|
||||
for plugin_id in plugin_ids:
|
||||
await self._unload_plugin_unsafe(plugin_id)
|
||||
|
||||
logger.info("✅ 已清理所有MCP插件和资源")
|
||||
|
||||
|
||||
# 全局注册表实例
|
||||
mcp_registry = MCPPluginRegistry()
|
||||
@@ -0,0 +1,50 @@
|
||||
"""MCP插件状态同步服务
|
||||
|
||||
将内存中的会话状态变更同步到数据库,确保状态一致性。
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from app.models.mcp_plugin import MCPPlugin
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def sync_status_to_db(event: Dict[str, Any]):
|
||||
"""
|
||||
状态变更回调 - 同步到数据库
|
||||
"""
|
||||
user_id = event["user_id"]
|
||||
plugin_name = event["plugin_name"]
|
||||
new_status = event["new_status"]
|
||||
reason = event.get("reason", "")
|
||||
|
||||
try:
|
||||
from app.database import get_engine
|
||||
|
||||
engine = await get_engine(user_id)
|
||||
AsyncSessionLocal = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
stmt = (
|
||||
update(MCPPlugin)
|
||||
.where(MCPPlugin.user_id == user_id, MCPPlugin.plugin_name == plugin_name)
|
||||
.values(status=new_status, last_error=reason if new_status == "error" else None)
|
||||
)
|
||||
await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
logger.debug(f"✅ 状态已同步到数据库: {plugin_name} -> {new_status}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 状态同步失败: {plugin_name}, 错误: {e}")
|
||||
|
||||
|
||||
def register_status_sync():
|
||||
"""注册状态同步回调到MCP客户端"""
|
||||
from app.mcp import mcp_client
|
||||
mcp_client.register_status_callback(sync_status_to_db)
|
||||
logger.info("✅ MCP状态同步服务已注册")
|
||||
@@ -1,5 +1,5 @@
|
||||
"""职业相关的Pydantic模型"""
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
@@ -63,8 +63,7 @@ class CareerResponse(BaseModel):
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class CareerListResponse(BaseModel):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""章节相关的Pydantic模型"""
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
@@ -58,8 +58,7 @@ class ChapterResponse(BaseModel):
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ChapterListResponse(BaseModel):
|
||||
@@ -142,8 +141,7 @@ class ExpansionPlanUpdate(BaseModel):
|
||||
estimated_words: Optional[int] = Field(None, description="预估字数", ge=500, le=10000)
|
||||
scenes: Optional[List[SceneData]] = Field(None, description="场景列表")
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
model_config = ConfigDict(json_schema_extra={
|
||||
"example": {
|
||||
"key_events": ["主角遇到挑战", "关键决策时刻"],
|
||||
"character_focus": ["张三", "李四"],
|
||||
@@ -159,7 +157,7 @@ class ExpansionPlanUpdate(BaseModel):
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
class ExpansionPlanResponse(BaseModel):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""角色相关的Pydantic模型"""
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
@@ -98,8 +98,7 @@ class CharacterResponse(CharacterBase):
|
||||
main_career_stage: Optional[int] = Field(None, description="主职业阶段")
|
||||
sub_careers: Optional[List[Dict[str, Any]]] = Field(None, description="副职业列表")
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class CharacterGenerateRequest(BaseModel):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""MCP插件Pydantic模式"""
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
|
||||
@@ -82,8 +82,7 @@ class MCPPluginResponse(BaseModel):
|
||||
# 时间戳
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class MCPToolCall(BaseModel):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""大纲相关的Pydantic模型"""
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
@@ -103,8 +103,7 @@ class OutlineResponse(BaseModel):
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class OutlineGenerateRequest(BaseModel):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""项目相关的Pydantic模型"""
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Optional, Literal
|
||||
from datetime import datetime
|
||||
|
||||
@@ -59,8 +59,7 @@ class ProjectResponse(ProjectBase):
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ProjectListResponse(BaseModel):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""关系管理相关的Pydantic模型"""
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
|
||||
@@ -17,8 +17,7 @@ class RelationshipTypeResponse(BaseModel):
|
||||
description: Optional[str] = None
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# ============ 角色关系相关 ============
|
||||
@@ -62,8 +61,7 @@ class CharacterRelationshipResponse(CharacterRelationshipBase):
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class RelationshipGraphNode(BaseModel):
|
||||
@@ -127,8 +125,7 @@ class OrganizationResponse(OrganizationBase):
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class OrganizationDetailResponse(BaseModel):
|
||||
@@ -185,8 +182,7 @@ class OrganizationMemberResponse(OrganizationMemberBase):
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class OrganizationMemberDetailResponse(BaseModel):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""写作风格 Schema"""
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
@@ -48,8 +48,7 @@ class WritingStyleResponse(BaseModel):
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class WritingStyleListResponse(BaseModel):
|
||||
|
||||
@@ -71,7 +71,18 @@ class AnthropicClient:
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""
|
||||
流式生成,支持工具调用
|
||||
|
||||
Yields:
|
||||
Dict with keys:
|
||||
- content: str - 文本内容块
|
||||
- tool_calls: list - 工具调用列表(如果有)
|
||||
- done: bool - 是否结束
|
||||
"""
|
||||
kwargs = {
|
||||
"model": model,
|
||||
"max_tokens": max_tokens,
|
||||
@@ -80,12 +91,42 @@ class AnthropicClient:
|
||||
}
|
||||
if system_prompt:
|
||||
kwargs["system"] = system_prompt
|
||||
if tools:
|
||||
kwargs["tools"] = tools
|
||||
if tool_choice == "required":
|
||||
kwargs["tool_choice"] = {"type": "any"}
|
||||
elif tool_choice == "auto":
|
||||
kwargs["tool_choice"] = {"type": "auto"}
|
||||
|
||||
try:
|
||||
async with self.client.messages.stream(**kwargs) as stream:
|
||||
try:
|
||||
async for text in stream.text_stream:
|
||||
yield text
|
||||
tool_calls = []
|
||||
async for chunk in stream:
|
||||
# 处理不同类型的块
|
||||
if chunk.type == "text_delta":
|
||||
yield {"content": chunk.text}
|
||||
elif chunk.type == "tool_use_delta":
|
||||
# 工具调用增量
|
||||
if not tool_calls or tool_calls[-1].get("id") != chunk.id:
|
||||
tool_calls.append({
|
||||
"id": chunk.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": chunk.name,
|
||||
"arguments": ""
|
||||
}
|
||||
})
|
||||
# 追加参数
|
||||
if tool_calls[-1]["function"]["arguments"] is None:
|
||||
tool_calls[-1]["function"]["arguments"] = ""
|
||||
tool_calls[-1]["function"]["arguments"] += chunk.input_gets_new_text or ""
|
||||
elif chunk.type == "message_delta":
|
||||
if chunk.stop_reason:
|
||||
# 流结束
|
||||
if tool_calls:
|
||||
yield {"tool_calls": tool_calls}
|
||||
yield {"done": True, "finish_reason": chunk.stop_reason}
|
||||
except GeneratorExit:
|
||||
# 生成器被关闭,这是正常的清理过程
|
||||
logger.debug("Anthropic 流式响应生成器被关闭(GeneratorExit)")
|
||||
|
||||
@@ -111,7 +111,18 @@ class GeminiClient:
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""
|
||||
流式生成,支持工具调用
|
||||
|
||||
Yields:
|
||||
Dict with keys:
|
||||
- content: str - 文本内容块
|
||||
- tool_calls: list - 工具调用列表(如果有)
|
||||
- done: bool - 是否结束
|
||||
"""
|
||||
url = f"{self.base_url}/models/{model}:streamGenerateContent?key={self.api_key}&alt=sse"
|
||||
|
||||
contents = []
|
||||
@@ -125,6 +136,8 @@ class GeminiClient:
|
||||
}
|
||||
if system_prompt:
|
||||
payload["systemInstruction"] = {"parts": [{"text": system_prompt}]}
|
||||
if tools:
|
||||
payload["tools"] = self._convert_tools_to_gemini(tools)
|
||||
|
||||
try:
|
||||
async with self.client.stream("POST", url, json=payload) as response:
|
||||
@@ -139,9 +152,26 @@ class GeminiClient:
|
||||
if candidates and len(candidates) > 0:
|
||||
parts = candidates[0].get("content", {}).get("parts", [])
|
||||
if parts and len(parts) > 0:
|
||||
text = parts[0].get("text", "")
|
||||
text = ""
|
||||
function_calls = []
|
||||
for part in parts:
|
||||
if "text" in part:
|
||||
text += part["text"]
|
||||
elif "functionCall" in part:
|
||||
fc = part["functionCall"]
|
||||
function_calls.append({
|
||||
"id": f"call_{fc['name']}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": fc["name"],
|
||||
"arguments": fc.get("args", {})
|
||||
}
|
||||
})
|
||||
|
||||
if text:
|
||||
yield text
|
||||
yield {"content": text}
|
||||
if function_calls:
|
||||
yield {"tool_calls": function_calls}
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
except GeneratorExit:
|
||||
|
||||
@@ -86,8 +86,21 @@ class OpenAIClient(BaseAIClient):
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
payload = self._build_payload(messages, model, temperature, max_tokens, stream=True)
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""
|
||||
流式生成,支持工具调用
|
||||
|
||||
Yields:
|
||||
Dict with keys:
|
||||
- content: str - 文本内容块
|
||||
- tool_calls: list - 工具调用列表(如果有)
|
||||
- done: bool - 是否结束
|
||||
"""
|
||||
payload = self._build_payload(messages, model, temperature, max_tokens, tools, tool_choice, stream=True)
|
||||
|
||||
tool_calls_buffer = {} # 收集工具调用块
|
||||
|
||||
try:
|
||||
async with await self._request_with_retry("POST", "/chat/completions", payload, stream=True) as response:
|
||||
@@ -97,14 +110,38 @@ class OpenAIClient(BaseAIClient):
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
if data_str.strip() == "[DONE]":
|
||||
# 流结束,检查是否有工具调用需要处理
|
||||
if tool_calls_buffer:
|
||||
yield {"tool_calls": list(tool_calls_buffer.values()), "done": True}
|
||||
yield {"done": True}
|
||||
break
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
choices = data.get("choices", [])
|
||||
if choices and len(choices) > 0:
|
||||
content = choices[0].get("delta", {}).get("content", "")
|
||||
delta = choices[0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
|
||||
# 检查工具调用
|
||||
tc_list = delta.get("tool_calls")
|
||||
if tc_list:
|
||||
for tc in tc_list:
|
||||
index = tc.get("index", 0)
|
||||
if index not in tool_calls_buffer:
|
||||
tool_calls_buffer[index] = tc
|
||||
else:
|
||||
existing = tool_calls_buffer[index]
|
||||
# 合并 function.arguments
|
||||
if "function" in tc and "function" in existing:
|
||||
if tc["function"].get("arguments"):
|
||||
existing["function"]["arguments"] = (
|
||||
existing["function"].get("arguments", "") +
|
||||
tc["function"]["arguments"]
|
||||
)
|
||||
|
||||
if content:
|
||||
yield content
|
||||
yield {"content": content}
|
||||
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
except GeneratorExit:
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
"""Anthropic Provider"""
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from app.logger import get_logger
|
||||
from app.services.ai_clients.anthropic_client import AnthropicClient
|
||||
from .base_provider import BaseAIProvider
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AnthropicProvider(BaseAIProvider):
|
||||
"""Anthropic 提供商"""
|
||||
@@ -39,7 +42,62 @@ class AnthropicProvider(BaseAIProvider):
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
# 如果有工具,使用真正的流式工具调用
|
||||
if tools:
|
||||
logger.debug(f"🔧 AnthropicProvider: 有 {len(tools)} 个工具,使用流式处理")
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
actual_tool_choice = tool_choice if tool_choice else "auto"
|
||||
|
||||
tool_calls_buffer = []
|
||||
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
tool_choice=actual_tool_choice,
|
||||
):
|
||||
# 检查是否有工具调用
|
||||
if chunk.get("tool_calls"):
|
||||
tool_calls_buffer.extend(chunk["tool_calls"])
|
||||
logger.debug(f"🔧 收到工具调用: {len(chunk['tool_calls'])} 个")
|
||||
|
||||
# 检查是否结束
|
||||
if chunk.get("done"):
|
||||
if tool_calls_buffer:
|
||||
logger.info(f"🔧 流式结束,处理 {len(tool_calls_buffer)} 个工具调用")
|
||||
from app.mcp import mcp_client
|
||||
actual_user_id = user_id or ""
|
||||
tool_results = await mcp_client.batch_call_tools(
|
||||
user_id=actual_user_id,
|
||||
tool_calls=tool_calls_buffer
|
||||
)
|
||||
# 将工具结果注入到上下文中
|
||||
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
|
||||
|
||||
# 构建最终提示词,要求AI基于工具结果回答
|
||||
final_prompt = f"{prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"
|
||||
final_messages = [{"role": "user", "content": final_prompt}]
|
||||
|
||||
# 递归调用生成最终结果
|
||||
async for final_chunk in self._generate_with_tools(
|
||||
final_messages, model, temperature, max_tokens, system_prompt, tools, user_id
|
||||
):
|
||||
yield final_chunk
|
||||
break
|
||||
|
||||
# 输出文本内容
|
||||
if chunk.get("content"):
|
||||
yield chunk["content"]
|
||||
return
|
||||
|
||||
# 无工具时普通流式生成
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
@@ -48,4 +106,56 @@ class AnthropicProvider(BaseAIProvider):
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
):
|
||||
yield chunk
|
||||
# 确保只 yield 字符串内容,避免 yield 字典导致类型错误
|
||||
if isinstance(chunk, dict):
|
||||
if chunk.get("content"):
|
||||
yield chunk["content"]
|
||||
else:
|
||||
yield chunk
|
||||
|
||||
async def _generate_with_tools(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: list = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""辅助方法:带工具的流式生成"""
|
||||
tool_calls_buffer = []
|
||||
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
):
|
||||
if chunk.get("tool_calls"):
|
||||
tool_calls_buffer.extend(chunk["tool_calls"])
|
||||
logger.debug(f"🔧 _generate_with_tools 收到工具调用: {len(chunk['tool_calls'])} 个")
|
||||
|
||||
if chunk.get("done"):
|
||||
if tool_calls_buffer:
|
||||
from app.mcp import mcp_client
|
||||
actual_user_id = user_id or ""
|
||||
tool_results = await mcp_client.batch_call_tools(
|
||||
user_id=actual_user_id,
|
||||
tool_calls=tool_calls_buffer
|
||||
)
|
||||
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
|
||||
|
||||
messages.append({"role": "user", "content": f"{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"})
|
||||
|
||||
async for final_chunk in self._generate_with_tools(
|
||||
messages, model, temperature, max_tokens, system_prompt, tools, user_id
|
||||
):
|
||||
yield final_chunk
|
||||
break
|
||||
|
||||
if chunk.get("content"):
|
||||
yield chunk["content"]
|
||||
@@ -28,6 +28,9 @@ class BaseAIProvider(ABC):
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""流式生成"""
|
||||
pass
|
||||
@@ -1,8 +1,12 @@
|
||||
"""Gemini Provider"""
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from app.logger import get_logger
|
||||
from app.services.ai_clients.gemini_client import GeminiClient
|
||||
from .base_provider import BaseAIProvider
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class GeminiProvider(BaseAIProvider):
|
||||
def __init__(self, client: GeminiClient):
|
||||
@@ -36,7 +40,62 @@ class GeminiProvider(BaseAIProvider):
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
# 如果有工具,使用真正的流式工具调用
|
||||
if tools:
|
||||
logger.debug(f"🔧 GeminiProvider: 有 {len(tools)} 个工具,使用流式处理")
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
actual_tool_choice = tool_choice if tool_choice else "auto"
|
||||
|
||||
tool_calls_buffer = []
|
||||
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
tool_choice=actual_tool_choice,
|
||||
):
|
||||
# 检查是否有工具调用
|
||||
if chunk.get("tool_calls"):
|
||||
tool_calls_buffer.extend(chunk["tool_calls"])
|
||||
logger.debug(f"🔧 收到工具调用: {len(chunk['tool_calls'])} 个")
|
||||
|
||||
# 检查是否结束
|
||||
if chunk.get("done"):
|
||||
if tool_calls_buffer:
|
||||
logger.info(f"🔧 流式结束,处理 {len(tool_calls_buffer)} 个工具调用")
|
||||
from app.mcp import mcp_client
|
||||
actual_user_id = user_id or ""
|
||||
tool_results = await mcp_client.batch_call_tools(
|
||||
user_id=actual_user_id,
|
||||
tool_calls=tool_calls_buffer
|
||||
)
|
||||
# 将工具结果注入到上下文中
|
||||
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
|
||||
|
||||
# 构建最终提示词,要求AI基于工具结果回答
|
||||
final_prompt = f"{prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"
|
||||
final_messages = [{"role": "user", "content": final_prompt}]
|
||||
|
||||
# 递归调用生成最终结果
|
||||
async for final_chunk in self._generate_with_tools(
|
||||
final_messages, model, temperature, max_tokens, system_prompt, tools, user_id
|
||||
):
|
||||
yield final_chunk
|
||||
break
|
||||
|
||||
# 输出文本内容
|
||||
if chunk.get("content"):
|
||||
yield chunk["content"]
|
||||
return
|
||||
|
||||
# 无工具时普通流式生成
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
@@ -45,4 +104,56 @@ class GeminiProvider(BaseAIProvider):
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
):
|
||||
yield chunk
|
||||
# 确保只 yield 字符串内容,避免 yield 字典导致类型错误
|
||||
if isinstance(chunk, dict):
|
||||
if chunk.get("content"):
|
||||
yield chunk["content"]
|
||||
else:
|
||||
yield chunk
|
||||
|
||||
async def _generate_with_tools(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: list = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""辅助方法:带工具的流式生成"""
|
||||
tool_calls_buffer = []
|
||||
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
):
|
||||
if chunk.get("tool_calls"):
|
||||
tool_calls_buffer.extend(chunk["tool_calls"])
|
||||
logger.debug(f"🔧 _generate_with_tools 收到工具调用: {len(chunk['tool_calls'])} 个")
|
||||
|
||||
if chunk.get("done"):
|
||||
if tool_calls_buffer:
|
||||
from app.mcp import mcp_client
|
||||
actual_user_id = user_id or ""
|
||||
tool_results = await mcp_client.batch_call_tools(
|
||||
user_id=actual_user_id,
|
||||
tool_calls=tool_calls_buffer
|
||||
)
|
||||
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
|
||||
|
||||
messages.append({"role": "user", "content": f"{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"})
|
||||
|
||||
async for final_chunk in self._generate_with_tools(
|
||||
messages, model, temperature, max_tokens, system_prompt, tools, user_id
|
||||
):
|
||||
yield final_chunk
|
||||
break
|
||||
|
||||
if chunk.get("content"):
|
||||
yield chunk["content"]
|
||||
@@ -1,9 +1,12 @@
|
||||
"""OpenAI Provider"""
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from app.logger import get_logger
|
||||
from app.services.ai_clients.openai_client import OpenAIClient
|
||||
from .base_provider import BaseAIProvider
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIProvider(BaseAIProvider):
|
||||
"""OpenAI 提供商"""
|
||||
@@ -42,16 +45,117 @@ class OpenAIProvider(BaseAIProvider):
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
# 如果有工具,使用真正的流式工具调用
|
||||
if tools:
|
||||
logger.debug(f"🔧 OpenAIProvider: 有 {len(tools)} 个工具,使用流式处理")
|
||||
actual_tool_choice = tool_choice if tool_choice else "auto"
|
||||
|
||||
tool_calls_buffer = []
|
||||
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
tools=tools,
|
||||
tool_choice=actual_tool_choice,
|
||||
):
|
||||
# 检查是否有工具调用
|
||||
if chunk.get("tool_calls"):
|
||||
tool_calls_buffer.extend(chunk["tool_calls"])
|
||||
logger.debug(f"🔧 收到工具调用: {len(chunk['tool_calls'])} 个")
|
||||
|
||||
# 检查是否结束
|
||||
if chunk.get("done"):
|
||||
if tool_calls_buffer:
|
||||
logger.info(f"🔧 流式结束,处理 {len(tool_calls_buffer)} 个工具调用")
|
||||
from app.mcp import mcp_client
|
||||
actual_user_id = user_id or ""
|
||||
tool_results = await mcp_client.batch_call_tools(
|
||||
user_id=actual_user_id,
|
||||
tool_calls=tool_calls_buffer
|
||||
)
|
||||
# 将工具结果注入到上下文中
|
||||
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
|
||||
|
||||
# 构建最终提示词,要求AI基于工具结果回答
|
||||
final_prompt = f"{prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"
|
||||
final_messages = messages.copy()
|
||||
final_messages.append({"role": "user", "content": final_prompt})
|
||||
|
||||
# 递归调用生成最终结果
|
||||
async for final_chunk in self._generate_with_tools(
|
||||
final_messages, model, temperature, max_tokens, tools, user_id
|
||||
):
|
||||
yield final_chunk
|
||||
break
|
||||
|
||||
# 输出文本内容
|
||||
if chunk.get("content"):
|
||||
yield chunk["content"]
|
||||
return
|
||||
|
||||
# 无工具时普通流式生成
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
):
|
||||
yield chunk
|
||||
# 确保只 yield 字符串内容,避免 yield 字典导致类型错误
|
||||
if isinstance(chunk, dict):
|
||||
if chunk.get("content"):
|
||||
yield chunk["content"]
|
||||
else:
|
||||
yield chunk
|
||||
|
||||
async def _generate_with_tools(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
tools: list,
|
||||
user_id: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""辅助方法:带工具的流式生成(无tool_choice,AI自由决定)"""
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
):
|
||||
if chunk.get("tool_calls"):
|
||||
from app.mcp import mcp_client
|
||||
actual_user_id = user_id or ""
|
||||
tool_results = await mcp_client.batch_call_tools(
|
||||
user_id=actual_user_id,
|
||||
tool_calls=chunk["tool_calls"]
|
||||
)
|
||||
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
|
||||
|
||||
# 再次调用获取最终回答
|
||||
messages.append({"role": "user", "content": f"{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"})
|
||||
|
||||
async for final_chunk in self._generate_with_tools(
|
||||
messages, model, temperature, max_tokens, tools, user_id
|
||||
):
|
||||
yield final_chunk
|
||||
break
|
||||
|
||||
if chunk.get("done"):
|
||||
break
|
||||
|
||||
if chunk.get("content"):
|
||||
yield chunk["content"]
|
||||
+391
-112
@@ -1,4 +1,10 @@
|
||||
"""AI服务封装 - 统一的AI接口"""
|
||||
"""AI服务封装 - 统一的AI接口
|
||||
|
||||
重构后支持自动MCP工具加载:
|
||||
- 所有AI方法在请求前自动检查用户MCP配置
|
||||
- 如果有启用的MCP插件且有可用工具,自动发送tools
|
||||
- 通过 auto_mcp 参数控制是否启用自动工具加载
|
||||
"""
|
||||
from typing import Optional, AsyncGenerator, List, Dict, Any, Union
|
||||
|
||||
from app.config import settings as app_settings
|
||||
@@ -13,7 +19,6 @@ from app.services.ai_providers.anthropic_provider import AnthropicProvider
|
||||
from app.services.ai_providers.gemini_provider import GeminiProvider
|
||||
from app.services.ai_providers.base_provider import BaseAIProvider
|
||||
from app.services.json_helper import clean_json_response, parse_json
|
||||
from app.mcp.adapters.universal import universal_mcp_adapter
|
||||
|
||||
# 导出清理函数
|
||||
cleanup_http_clients = cleanup_all_clients
|
||||
@@ -22,7 +27,41 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AIService:
|
||||
"""AI服务统一接口"""
|
||||
"""
|
||||
AI服务统一接口
|
||||
|
||||
MCP工具支持:
|
||||
- 在创建服务时传入 user_id 和 db_session
|
||||
- 根据用户MCP插件的enabled状态自动决定是否启用MCP
|
||||
- 如果有任意一个MCP插件启用,则加载并使用工具
|
||||
- 如果所有插件都关闭,则不使用任何MCP工具
|
||||
- 通过 auto_mcp=False 可临时禁用自动工具加载
|
||||
- 通过 mcp_max_rounds 控制工具调用轮数
|
||||
- 通过 clear_mcp_cache() 可清理MCP工具缓存
|
||||
|
||||
MCP启用逻辑(backend/app/api/settings.py 中的 get_user_ai_service):
|
||||
- 查询用户的所有MCP插件
|
||||
- 如果有启用的插件 (enabled=True),则 enable_mcp=True
|
||||
- 如果所有插件都关闭或没有插件,则 enable_mcp=False
|
||||
|
||||
使用示例:
|
||||
# 创建支持MCP的AI服务(根据插件状态自动决定是否启用)
|
||||
ai_service = create_user_ai_service_with_mcp(
|
||||
api_provider="openai",
|
||||
api_key="...",
|
||||
user_id="user123",
|
||||
db_session=db
|
||||
)
|
||||
|
||||
# 自动加载MCP工具(如果有启用的插件)
|
||||
result = await ai_service.generate_text(prompt="...")
|
||||
|
||||
# 临时禁用MCP工具
|
||||
result = await ai_service.generate_text(prompt="...", auto_mcp=False)
|
||||
|
||||
# 自定义轮数
|
||||
result = await ai_service.generate_text(prompt="...", mcp_max_rounds=3)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -33,8 +72,11 @@ class AIService:
|
||||
default_temperature: Optional[float] = None,
|
||||
default_max_tokens: Optional[int] = None,
|
||||
default_system_prompt: Optional[str] = None,
|
||||
enable_mcp_adapter: bool = True,
|
||||
config: Optional[AIClientConfig] = None,
|
||||
# MCP支持参数
|
||||
user_id: Optional[str] = None,
|
||||
db_session: Optional[Any] = None,
|
||||
enable_mcp: bool = True,
|
||||
):
|
||||
self.api_provider = api_provider or app_settings.default_ai_provider
|
||||
self.default_model = default_model or app_settings.default_model
|
||||
@@ -43,7 +85,12 @@ class AIService:
|
||||
self.default_system_prompt = default_system_prompt
|
||||
self.config = config or default_config
|
||||
|
||||
self.mcp_adapter = universal_mcp_adapter if enable_mcp_adapter else None
|
||||
# MCP配置
|
||||
self.user_id = user_id
|
||||
self.db_session = db_session
|
||||
self._enable_mcp = enable_mcp
|
||||
self._cached_tools: Optional[List[Dict]] = None
|
||||
self._tools_loaded = False
|
||||
|
||||
self._openai_provider: Optional[OpenAIProvider] = None
|
||||
self._anthropic_provider: Optional[AnthropicProvider] = None
|
||||
@@ -68,6 +115,36 @@ class AIService:
|
||||
client = GeminiClient(api_key, api_base_url, self.config)
|
||||
self._gemini_provider = GeminiProvider(client)
|
||||
|
||||
@property
|
||||
def enable_mcp(self) -> bool:
|
||||
"""是否启用MCP工具"""
|
||||
return self._enable_mcp
|
||||
|
||||
@enable_mcp.setter
|
||||
def enable_mcp(self, value: bool):
|
||||
"""设置MCP启用状态,如果禁用则清理缓存"""
|
||||
if value is False and self._enable_mcp is True:
|
||||
# 从启用变为禁用,清理缓存
|
||||
self.clear_mcp_cache()
|
||||
self._enable_mcp = value
|
||||
|
||||
def clear_mcp_cache(self):
|
||||
"""
|
||||
清理MCP工具缓存
|
||||
|
||||
当禁用MCP时调用此方法,确保后续AI调用不会使用缓存的工具。
|
||||
同时更新 _tools_loaded 状态,使下次调用时重新检查。
|
||||
"""
|
||||
if self._cached_tools is not None:
|
||||
logger.info(f"🔧 清理MCP工具缓存,移除 {len(self._cached_tools)} 个工具")
|
||||
self._cached_tools = None
|
||||
else:
|
||||
logger.debug(f"🔧 MCP工具缓存已经是空,无需清理")
|
||||
|
||||
# 更新加载状态,确保下次调用会重新检查
|
||||
self._tools_loaded = False
|
||||
logger.debug(f"🔧 MCP工具状态已重置: enable_mcp={self._enable_mcp}, _tools_loaded=False")
|
||||
|
||||
def _get_provider(self, provider: Optional[str] = None) -> BaseAIProvider:
|
||||
"""获取对应的 Provider"""
|
||||
p = provider or self.api_provider
|
||||
@@ -79,6 +156,166 @@ class AIService:
|
||||
return self._gemini_provider
|
||||
raise ValueError(f"Provider {p} 未初始化")
|
||||
|
||||
async def _prepare_mcp_tools(self, auto_mcp: bool = True, force_refresh: bool = False) -> Optional[List[Dict]]:
|
||||
"""
|
||||
预处理MCP工具
|
||||
|
||||
检查用户MCP配置并加载可用工具。
|
||||
结果会被缓存,避免重复加载。
|
||||
|
||||
Args:
|
||||
auto_mcp: 是否自动加载MCP工具(来自调用方参数)
|
||||
force_refresh: 是否强制刷新缓存
|
||||
|
||||
Returns:
|
||||
- None: 无可用工具(未配置/未启用/加载失败)
|
||||
- List[Dict]: OpenAI格式的工具列表
|
||||
"""
|
||||
# 前置条件检查
|
||||
if not self._enable_mcp:
|
||||
logger.debug(f"🔧 MCP工具未启用 (_enable_mcp=False)")
|
||||
# 即使有缓存也清理掉,确保不使用
|
||||
self._cached_tools = None
|
||||
self._tools_loaded = False
|
||||
return None
|
||||
|
||||
if not auto_mcp:
|
||||
logger.debug(f"🔧 auto_mcp=False,跳过MCP工具加载")
|
||||
# 即使有缓存也清理掉,确保不使用
|
||||
self._cached_tools = None
|
||||
self._tools_loaded = False
|
||||
return None
|
||||
|
||||
if not self.user_id:
|
||||
logger.debug(f"🔧 MCP工具加载跳过: user_id未设置")
|
||||
return None
|
||||
|
||||
if not self.db_session:
|
||||
logger.debug(f"🔧 MCP工具加载跳过: db_session未设置")
|
||||
return None
|
||||
|
||||
# 使用缓存(只有 enable_mcp=True 时才使用缓存)
|
||||
if self._tools_loaded and not force_refresh:
|
||||
if self._cached_tools:
|
||||
logger.debug(f"🔧 使用缓存的MCP工具 ({len(self._cached_tools)}个)")
|
||||
return self._cached_tools
|
||||
|
||||
try:
|
||||
from app.services.mcp_tools_loader import mcp_tools_loader
|
||||
|
||||
self._cached_tools = await mcp_tools_loader.get_user_tools(
|
||||
user_id=self.user_id,
|
||||
db_session=self.db_session,
|
||||
use_cache=True,
|
||||
force_refresh=force_refresh
|
||||
)
|
||||
self._tools_loaded = True
|
||||
|
||||
if self._cached_tools:
|
||||
logger.info(f"🔧 已加载 {len(self._cached_tools)} 个MCP工具")
|
||||
else:
|
||||
logger.debug(f"📭 用户 {self.user_id} 没有可用的MCP工具")
|
||||
|
||||
return self._cached_tools
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 加载MCP工具失败: {e}")
|
||||
self._tools_loaded = True
|
||||
self._cached_tools = None
|
||||
return None
|
||||
|
||||
async def _handle_tool_calls(
|
||||
self,
|
||||
original_prompt: str,
|
||||
response: Dict[str, Any],
|
||||
max_rounds: int = 2,
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
处理AI返回的工具调用
|
||||
|
||||
Args:
|
||||
original_prompt: 原始提示词
|
||||
response: AI响应(包含tool_calls)
|
||||
max_rounds: 最大工具调用轮数
|
||||
**kwargs: 传递给generate_text的其他参数
|
||||
|
||||
Returns:
|
||||
最终的AI响应
|
||||
"""
|
||||
from app.mcp import mcp_client
|
||||
|
||||
tool_calls = response.get("tool_calls", [])
|
||||
if not tool_calls or not self.user_id:
|
||||
return response
|
||||
|
||||
result = {
|
||||
"content": response.get("content", ""),
|
||||
"tool_calls_made": 0,
|
||||
"tools_used": [],
|
||||
"finish_reason": response.get("finish_reason", ""),
|
||||
"mcp_enhanced": True
|
||||
}
|
||||
|
||||
prompt = original_prompt
|
||||
|
||||
for round_num in range(max_rounds):
|
||||
logger.info(f"🔧 工具调用 - 第{round_num+1}/{max_rounds}轮,{len(tool_calls)}个工具")
|
||||
|
||||
try:
|
||||
# 批量执行工具调用
|
||||
tool_results = await mcp_client.batch_call_tools(
|
||||
user_id=self.user_id,
|
||||
tool_calls=tool_calls
|
||||
)
|
||||
|
||||
# 记录使用的工具
|
||||
for tc in tool_calls:
|
||||
name = tc["function"]["name"]
|
||||
if name not in result["tools_used"]:
|
||||
result["tools_used"].append(name)
|
||||
result["tool_calls_made"] += len(tool_calls)
|
||||
|
||||
# 构建工具上下文
|
||||
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
|
||||
|
||||
# 更新提示词
|
||||
if round_num == max_rounds - 1:
|
||||
# 最后一轮,强制要求回答
|
||||
prompt = f"{original_prompt}\n\n{tool_context}\n\n⚠️ 重要:请基于以上工具查询结果,给出完整详细的最终答案。不要再调用工具。"
|
||||
tool_choice = "none"
|
||||
else:
|
||||
prompt = f"{original_prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,继续完成任务。"
|
||||
tool_choice = kwargs.get("tool_choice", "auto")
|
||||
|
||||
# 继续调用AI
|
||||
prov = self._get_provider(kwargs.get("provider"))
|
||||
next_response = await prov.generate(
|
||||
prompt=prompt,
|
||||
model=kwargs.get("model") or self.default_model,
|
||||
temperature=kwargs.get("temperature") or self.default_temperature,
|
||||
max_tokens=kwargs.get("max_tokens") or self.default_max_tokens,
|
||||
system_prompt=kwargs.get("system_prompt") or self.default_system_prompt,
|
||||
tools=None if tool_choice == "none" else self._cached_tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
|
||||
tool_calls = next_response.get("tool_calls", [])
|
||||
|
||||
if not tool_calls:
|
||||
# 没有更多工具调用,返回结果
|
||||
result["content"] = next_response.get("content", "")
|
||||
result["finish_reason"] = next_response.get("finish_reason", "stop")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 工具调用失败: {e}")
|
||||
result["content"] = response.get("content", "")
|
||||
result["finish_reason"] = "tool_error"
|
||||
break
|
||||
|
||||
return result
|
||||
|
||||
async def generate_text(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -89,10 +326,39 @@ class AIService:
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
auto_mcp: bool = True,
|
||||
handle_tool_calls: bool = True,
|
||||
mcp_max_rounds: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""生成文本"""
|
||||
"""
|
||||
生成文本(自动支持MCP工具)
|
||||
|
||||
Args:
|
||||
prompt: 用户提示词
|
||||
provider: AI提供商
|
||||
model: 模型名称
|
||||
temperature: 温度
|
||||
max_tokens: 最大令牌数
|
||||
system_prompt: 系统提示词
|
||||
tools: 手动指定的工具列表(优先级高于自动加载)
|
||||
tool_choice: 工具选择策略
|
||||
auto_mcp: 是否自动加载MCP工具(默认True)
|
||||
handle_tool_calls: 是否自动处理工具调用(默认True)
|
||||
mcp_max_rounds: 最大工具调用轮数(None使用默认值3)
|
||||
|
||||
Returns:
|
||||
包含生成内容的字典
|
||||
"""
|
||||
# 使用全局配置的MCP轮数(如果未指定)
|
||||
if mcp_max_rounds is None:
|
||||
mcp_max_rounds = app_settings.mcp_max_rounds
|
||||
|
||||
# 自动加载MCP工具
|
||||
if auto_mcp and tools is None:
|
||||
tools = await self._prepare_mcp_tools(auto_mcp=auto_mcp)
|
||||
|
||||
prov = self._get_provider(provider)
|
||||
return await prov.generate(
|
||||
response = await prov.generate(
|
||||
prompt=prompt,
|
||||
model=model or self.default_model,
|
||||
temperature=temperature or self.default_temperature,
|
||||
@@ -101,6 +367,22 @@ class AIService:
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
|
||||
# 处理工具调用
|
||||
if handle_tool_calls and response.get("tool_calls"):
|
||||
return await self._handle_tool_calls(
|
||||
original_prompt=prompt,
|
||||
response=response,
|
||||
provider=provider,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
tool_choice=tool_choice,
|
||||
max_rounds=mcp_max_rounds,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
async def generate_text_stream(
|
||||
self,
|
||||
@@ -110,15 +392,51 @@ class AIService:
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
auto_mcp: bool = True,
|
||||
mcp_max_rounds: Optional[int] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""流式生成"""
|
||||
"""
|
||||
流式生成文本(自动支持MCP工具)
|
||||
|
||||
工具调用在 Provider 层通过流式方式处理,支持真正的流式工具调用。
|
||||
|
||||
Args:
|
||||
prompt: 用户提示词
|
||||
provider: AI提供商
|
||||
model: 模型名称
|
||||
temperature: 温度
|
||||
max_tokens: 最大令牌数
|
||||
system_prompt: 系统提示词
|
||||
tool_choice: 工具选择策略("auto"/"none"/"required")
|
||||
auto_mcp: 是否自动加载MCP工具
|
||||
mcp_max_rounds: 最大工具调用轮数(None使用默认值3)
|
||||
|
||||
Yields:
|
||||
生成的文本块
|
||||
"""
|
||||
logger.debug(f"🔧 generate_text_stream: auto_mcp={auto_mcp}, tool_choice={tool_choice}")
|
||||
|
||||
tools_to_use = None
|
||||
|
||||
# 加载MCP工具
|
||||
if auto_mcp:
|
||||
tools_to_use = await self._prepare_mcp_tools(auto_mcp=auto_mcp)
|
||||
if tools_to_use:
|
||||
logger.info(f"🔧 已获取 {len(tools_to_use)} 个MCP工具")
|
||||
|
||||
# 流式生成(Provider 层处理工具调用)
|
||||
prov = self._get_provider(provider)
|
||||
logger.debug(f"🔧 开始流式生成,provider={provider or self.api_provider}, tools_count={len(tools_to_use) if tools_to_use else 0}")
|
||||
async for chunk in prov.generate_stream(
|
||||
prompt=prompt,
|
||||
model=model or self.default_model,
|
||||
temperature=temperature or self.default_temperature,
|
||||
max_tokens=max_tokens or self.default_max_tokens,
|
||||
system_prompt=system_prompt or self.default_system_prompt,
|
||||
tools=tools_to_use,
|
||||
tool_choice=tool_choice,
|
||||
user_id=self.user_id,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
@@ -132,8 +450,25 @@ class AIService:
|
||||
provider: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
expected_type: Optional[str] = None,
|
||||
auto_mcp: bool = True,
|
||||
) -> Union[Dict, List]:
|
||||
"""带重试的 JSON 调用"""
|
||||
"""
|
||||
带重试的 JSON 调用(自动支持MCP工具)
|
||||
|
||||
Args:
|
||||
prompt: 用户提示词
|
||||
system_prompt: 系统提示词
|
||||
max_retries: 最大重试次数
|
||||
temperature: 温度
|
||||
max_tokens: 最大令牌数
|
||||
provider: AI提供商
|
||||
model: 模型名称
|
||||
expected_type: 期望的返回类型("object"或"array")
|
||||
auto_mcp: 是否自动加载MCP工具
|
||||
|
||||
Returns:
|
||||
解析后的JSON数据
|
||||
"""
|
||||
last_response = ""
|
||||
|
||||
for attempt in range(1, max_retries + 1):
|
||||
@@ -146,6 +481,8 @@ class AIService:
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
auto_mcp=auto_mcp,
|
||||
handle_tool_calls=True,
|
||||
)
|
||||
|
||||
last_response = result.get("content", "")
|
||||
@@ -172,108 +509,6 @@ class AIService:
|
||||
"""清洗 JSON 响应"""
|
||||
return clean_json_response(text)
|
||||
|
||||
async def generate_text_with_mcp(
|
||||
self,
|
||||
prompt: str,
|
||||
user_id: str,
|
||||
db_session,
|
||||
enable_mcp: bool = True,
|
||||
max_tool_rounds: int = 3,
|
||||
tool_choice: str = "auto",
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""支持MCP工具的AI文本生成"""
|
||||
from app.services.mcp_tool_service import mcp_tool_service, MCPToolServiceError
|
||||
|
||||
result = {"content": "", "tool_calls_made": 0, "tools_used": [], "finish_reason": "", "mcp_enhanced": False}
|
||||
tools = None
|
||||
|
||||
if enable_mcp:
|
||||
try:
|
||||
tools = await mcp_tool_service.get_user_enabled_tools(user_id=user_id, db_session=db_session)
|
||||
if tools:
|
||||
result["mcp_enhanced"] = True
|
||||
except MCPToolServiceError:
|
||||
tools = None
|
||||
|
||||
original_prompt = prompt # 保存原始提示词
|
||||
|
||||
for round_num in range(max_tool_rounds):
|
||||
logger.debug(f"🔄 MCP工具调用 - 第{round_num+1}/{max_tool_rounds}轮")
|
||||
logger.debug(f" prompt长度: {len(prompt)}, tools数量: {len(tools) if tools else 0}, tool_choice: {tool_choice}")
|
||||
|
||||
ai_response = await self.generate_text(prompt=prompt, tools=tools, tool_choice=tool_choice, **kwargs)
|
||||
logger.debug(f" AI响应: finish_reason={ai_response.get('finish_reason')}, content长度={len(ai_response.get('content', ''))}")
|
||||
|
||||
tool_calls = ai_response.get("tool_calls", [])
|
||||
|
||||
if not tool_calls:
|
||||
content = ai_response.get("content", "")
|
||||
result["content"] = content
|
||||
result["finish_reason"] = ai_response.get("finish_reason", "stop")
|
||||
logger.debug(f" ✅ 无工具调用,返回内容长度: {len(content)}")
|
||||
|
||||
# 🔧 修复:如果内容为空且已经调用过工具,强制要求AI给出答案
|
||||
if not content.strip() and result["tool_calls_made"] > 0:
|
||||
logger.warning(f"⚠️ AI在工具调用后返回空内容,尝试强制要求回答(第{round_num+1}轮)")
|
||||
prompt = f"{prompt}\n\n⚠️ 请注意:你必须基于以上工具查询结果,给出完整的回答。不要返回空内容。"
|
||||
tools = None
|
||||
tool_choice = "none" # 强制不使用工具
|
||||
continue
|
||||
|
||||
break
|
||||
|
||||
logger.info(f"🔧 检测到 {len(tool_calls)} 个工具调用")
|
||||
for idx, tc in enumerate(tool_calls):
|
||||
logger.debug(f" 工具{idx+1}: {tc.get('function', {}).get('name')} - 参数: {tc.get('function', {}).get('arguments')}")
|
||||
|
||||
try:
|
||||
logger.debug(f" 开始执行工具调用...")
|
||||
tool_results = await mcp_tool_service.execute_tool_calls(user_id=user_id, tool_calls=tool_calls, db_session=db_session)
|
||||
logger.debug(f" 工具执行完成,结果数量: {len(tool_results)}")
|
||||
|
||||
# 🔍 检查工具结果
|
||||
for idx, tr in enumerate(tool_results):
|
||||
success = tr.get("success", False)
|
||||
content_preview = tr.get("content", "")[:200] if tr.get("content") else "None"
|
||||
logger.debug(f" 工具结果[{idx}]: success={success}, content预览={content_preview}")
|
||||
|
||||
for tc in tool_calls:
|
||||
name = tc["function"]["name"]
|
||||
if name not in result["tools_used"]:
|
||||
result["tools_used"].append(name)
|
||||
result["tool_calls_made"] += len(tool_calls)
|
||||
|
||||
tool_context = await mcp_tool_service.build_tool_context(tool_results, format="markdown")
|
||||
logger.debug(f" 工具上下文长度: {len(tool_context)}")
|
||||
logger.debug(f" 工具上下文预览: {tool_context[:300] if len(tool_context) > 300 else tool_context}")
|
||||
|
||||
# 🔧 改进:在最后一轮时,明确要求AI给出完整答案
|
||||
if round_num == max_tool_rounds - 1:
|
||||
logger.info(f"⚠️ 最后一轮,强制要求AI给出最终答案")
|
||||
prompt = f"{original_prompt}\n\n{tool_context}\n\n⚠️ 重要:这是最后一轮,请基于以上工具查询的参考资料,给出完整详细的最终答案。不要再调用工具。"
|
||||
tool_choice = "none"
|
||||
else:
|
||||
prompt = f"{original_prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,继续完成任务。"
|
||||
logger.debug(f" 新prompt长度: {len(prompt)}")
|
||||
|
||||
tools = None # 工具调用后禁用工具列表,避免重复调用
|
||||
logger.debug(f" ✅ 工具调用成功,准备下一轮")
|
||||
|
||||
except Exception as tool_error:
|
||||
logger.error(f"❌ 工具调用执行失败: {tool_error}", exc_info=True)
|
||||
logger.error(f" 错误类型: {type(tool_error).__name__}")
|
||||
logger.error(f" AI响应内容: {ai_response.get('content', '')[:200]}")
|
||||
result["content"] = ai_response.get("content", "")
|
||||
result["finish_reason"] = "tool_error"
|
||||
break
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# 全局实例
|
||||
ai_service = AIService()
|
||||
|
||||
|
||||
def create_user_ai_service(
|
||||
api_provider: str,
|
||||
@@ -284,7 +519,7 @@ def create_user_ai_service(
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
) -> AIService:
|
||||
"""创建用户 AI 服务"""
|
||||
"""创建用户 AI 服务(不带MCP支持)"""
|
||||
return AIService(
|
||||
api_provider=api_provider,
|
||||
api_key=api_key,
|
||||
@@ -293,4 +528,48 @@ def create_user_ai_service(
|
||||
default_temperature=temperature,
|
||||
default_max_tokens=max_tokens,
|
||||
default_system_prompt=system_prompt,
|
||||
)
|
||||
|
||||
|
||||
def create_user_ai_service_with_mcp(
|
||||
api_provider: str,
|
||||
api_key: str,
|
||||
api_base_url: str,
|
||||
model_name: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
user_id: str,
|
||||
db_session,
|
||||
system_prompt: Optional[str] = None,
|
||||
enable_mcp: bool = True,
|
||||
) -> AIService:
|
||||
"""
|
||||
创建支持MCP的用户AI服务
|
||||
|
||||
Args:
|
||||
api_provider: AI提供商
|
||||
api_key: API密钥
|
||||
api_base_url: API基础URL
|
||||
model_name: 模型名称
|
||||
temperature: 温度
|
||||
max_tokens: 最大令牌数
|
||||
user_id: 用户ID(用于加载MCP工具)
|
||||
db_session: 数据库会话
|
||||
system_prompt: 系统提示词
|
||||
enable_mcp: 是否启用MCP工具
|
||||
|
||||
Returns:
|
||||
配置好的AIService实例
|
||||
"""
|
||||
return AIService(
|
||||
api_provider=api_provider,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
default_model=model_name,
|
||||
default_temperature=temperature,
|
||||
default_max_tokens=max_tokens,
|
||||
default_system_prompt=system_prompt,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
enable_mcp=enable_mcp,
|
||||
)
|
||||
@@ -269,25 +269,11 @@ class AutoCharacterService:
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用AI分析(使用统一的JSON调用方法)
|
||||
if enable_mcp and user_id:
|
||||
result = await self.ai_service.generate_text_with_mcp(
|
||||
prompt=prompt,
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=2
|
||||
)
|
||||
content = result.get("content", "")
|
||||
# 使用统一的JSON清洗方法
|
||||
cleaned = self.ai_service._clean_json_response(content)
|
||||
analysis = json.loads(cleaned)
|
||||
else:
|
||||
# 非MCP调用:使用带自动重试的JSON调用
|
||||
analysis = await self.ai_service.call_with_json_retry(
|
||||
prompt=prompt,
|
||||
max_retries=3
|
||||
)
|
||||
# 使用统一的JSON调用方法(支持自动MCP工具加载)
|
||||
analysis = await self.ai_service.call_with_json_retry(
|
||||
prompt=prompt,
|
||||
max_retries=3,
|
||||
)
|
||||
|
||||
logger.info(f" ✅ AI分析完成: needs_new_characters={analysis.get('needs_new_characters')}")
|
||||
return analysis
|
||||
@@ -364,16 +350,16 @@ class AutoCharacterService:
|
||||
existing_characters=existing_chars_summary + careers_info,
|
||||
plot_context="根据剧情需要引入的新角色",
|
||||
character_specification=json.dumps(spec, ensure_ascii=False, indent=2),
|
||||
mcp_references="" # 暂时不使用MCP增强
|
||||
mcp_references="" # MCP工具通过AI服务自动加载
|
||||
)
|
||||
|
||||
# 调用AI生成(禁用MCP,避免累积超时导致卡死)
|
||||
logger.info(f"🔧 角色详情生成: enable_mcp={enable_mcp}")
|
||||
|
||||
# 调用AI生成
|
||||
try:
|
||||
# 🔧 优化:角色详情生成不使用MCP,只在分析阶段使用MCP
|
||||
# 这样可以减少大量的外部工具调用,避免超时和卡死
|
||||
character_data = await self.ai_service.call_with_json_retry(
|
||||
prompt=prompt,
|
||||
max_retries=2 # 减少重试次数以加快速度
|
||||
max_retries=2, # 减少重试次数以加快速度
|
||||
)
|
||||
|
||||
char_name = character_data.get('name', '未知')
|
||||
|
||||
@@ -292,25 +292,11 @@ class AutoOrganizationService:
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用AI分析(使用统一的JSON调用方法)
|
||||
if enable_mcp and user_id:
|
||||
result = await self.ai_service.generate_text_with_mcp(
|
||||
prompt=prompt,
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=2
|
||||
)
|
||||
content = result.get("content", "")
|
||||
# 使用统一的JSON清洗方法
|
||||
cleaned = self.ai_service._clean_json_response(content)
|
||||
analysis = json.loads(cleaned)
|
||||
else:
|
||||
# 非MCP调用:使用带自动重试的JSON调用
|
||||
analysis = await self.ai_service.call_with_json_retry(
|
||||
prompt=prompt,
|
||||
max_retries=3
|
||||
)
|
||||
# 使用统一的JSON调用方法(支持自动MCP工具加载)
|
||||
analysis = await self.ai_service.call_with_json_retry(
|
||||
prompt=prompt,
|
||||
max_retries=3,
|
||||
)
|
||||
|
||||
logger.info(f" ✅ AI分析完成: needs_new_organizations={analysis.get('needs_new_organizations')}")
|
||||
return analysis
|
||||
@@ -362,24 +348,11 @@ class AutoOrganizationService:
|
||||
|
||||
# 调用AI生成(使用统一的JSON调用方法)
|
||||
try:
|
||||
if enable_mcp and user_id:
|
||||
result = await self.ai_service.generate_text_with_mcp(
|
||||
prompt=prompt,
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=2
|
||||
)
|
||||
content = result.get("content", "")
|
||||
# 使用统一的JSON清洗方法
|
||||
cleaned = self.ai_service._clean_json_response(content)
|
||||
organization_data = json.loads(cleaned)
|
||||
else:
|
||||
# 非MCP调用:使用带自动重试的JSON调用
|
||||
organization_data = await self.ai_service.call_with_json_retry(
|
||||
prompt=prompt,
|
||||
max_retries=3
|
||||
)
|
||||
# 使用统一的JSON调用方法(支持自动MCP工具加载)
|
||||
organization_data = await self.ai_service.call_with_json_retry(
|
||||
prompt=prompt,
|
||||
max_retries=3,
|
||||
)
|
||||
|
||||
org_name = organization_data.get('name', '未知')
|
||||
logger.info(f" ✅ 组织详情生成成功: {org_name}")
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
"""MCP插件测试服务 - 专门处理插件测试逻辑"""
|
||||
"""MCP插件测试服务 - 专门处理插件测试逻辑
|
||||
|
||||
重构后使用统一的MCPClientFacade门面来管理所有MCP操作。
|
||||
"""
|
||||
|
||||
import time
|
||||
import json
|
||||
@@ -10,7 +13,7 @@ from sqlalchemy import select
|
||||
|
||||
from app.models.mcp_plugin import MCPPlugin
|
||||
from app.models.settings import Settings as UserSettings
|
||||
from app.mcp.registry import mcp_registry
|
||||
from app.mcp import mcp_client, MCPPluginConfig # 使用新的统一门面
|
||||
from app.services.ai_service import create_user_ai_service
|
||||
from app.schemas.mcp_plugin import MCPTestResult
|
||||
from app.services.prompt_service import prompt_service
|
||||
@@ -21,7 +24,32 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MCPTestService:
|
||||
"""MCP插件测试服务(分离的测试逻辑)"""
|
||||
"""MCP插件测试服务(使用统一门面重构)"""
|
||||
|
||||
async def _ensure_plugin_registered(
|
||||
self,
|
||||
plugin: MCPPlugin,
|
||||
user_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
确保插件已注册到统一门面
|
||||
|
||||
Args:
|
||||
plugin: 插件配置
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if plugin.plugin_type in ("http", "streamable_http", "sse") and plugin.server_url:
|
||||
return await mcp_client.ensure_registered(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin.plugin_name,
|
||||
url=plugin.server_url,
|
||||
plugin_type=plugin.plugin_type,
|
||||
headers=plugin.headers
|
||||
)
|
||||
return False
|
||||
|
||||
async def test_plugin_connection(
|
||||
self,
|
||||
@@ -41,19 +69,18 @@ class MCPTestService:
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 确保插件已加载
|
||||
if not mcp_registry.get_client(user_id, plugin.plugin_name):
|
||||
success = await mcp_registry.load_plugin(plugin)
|
||||
if not success:
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="插件加载失败",
|
||||
error="无法创建MCP客户端",
|
||||
suggestions=["请检查插件配置", "请确认服务器URL正确"]
|
||||
)
|
||||
# 确保插件已注册
|
||||
registered = await self._ensure_plugin_registered(plugin, user_id)
|
||||
if not registered:
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="插件注册失败",
|
||||
error="无法创建MCP客户端",
|
||||
suggestions=["请检查插件配置", "请确认服务器URL正确"]
|
||||
)
|
||||
|
||||
# 测试连接并获取工具列表
|
||||
test_result = await mcp_registry.test_plugin(user_id, plugin.plugin_name)
|
||||
# 使用统一门面测试连接
|
||||
test_result = await mcp_client.test_connection(user_id, plugin.plugin_name)
|
||||
|
||||
end_time = time.time()
|
||||
response_time = round((end_time - start_time) * 1000, 2)
|
||||
@@ -70,7 +97,18 @@ class MCPTestService:
|
||||
]
|
||||
)
|
||||
else:
|
||||
return MCPTestResult(**test_result)
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="❌ 连接测试失败",
|
||||
response_time_ms=response_time,
|
||||
error=test_result.get("message", "未知错误"),
|
||||
error_type=test_result.get("error_type"),
|
||||
suggestions=[
|
||||
"请检查服务器是否在线",
|
||||
"请确认配置正确",
|
||||
"请检查API Key是否有效"
|
||||
]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
@@ -117,8 +155,8 @@ class MCPTestService:
|
||||
if not connection_result.success:
|
||||
return connection_result
|
||||
|
||||
# 2. 获取工具列表
|
||||
tools = await mcp_registry.get_plugin_tools(user.user_id, plugin.plugin_name)
|
||||
# 2. 使用统一门面获取工具列表
|
||||
tools = await mcp_client.get_tools(user.user_id, plugin.plugin_name)
|
||||
|
||||
if not tools:
|
||||
return MCPTestResult(
|
||||
@@ -162,8 +200,8 @@ class MCPTestService:
|
||||
max_tokens=1000
|
||||
)
|
||||
|
||||
# 转换为OpenAI Function Calling格式
|
||||
openai_tools = self._convert_tools_to_openai_format(tools)
|
||||
# 使用统一门面转换为OpenAI Function Calling格式
|
||||
openai_tools = mcp_client.format_tools_for_openai(tools, plugin.plugin_name)
|
||||
|
||||
logger.info(f"📋 转换后的OpenAI工具数量: {len(openai_tools)}")
|
||||
logger.debug(f"📋 OpenAI工具列表: {[t['function']['name'] for t in openai_tools]}")
|
||||
@@ -175,26 +213,16 @@ class MCPTestService:
|
||||
db=db_session
|
||||
)
|
||||
|
||||
# 注意: generate_text_stream 返回的是异步生成器,但在 tool_choice="required" 模式下
|
||||
# AI服务会直接返回包含 tool_calls 的完整响应,而不是流式chunks
|
||||
# 因此这里需要特殊处理
|
||||
accumulated_text = ""
|
||||
tool_calls = None
|
||||
|
||||
async for chunk in ai_service.generate_text_stream(
|
||||
# 使用 generate_text 进行 Function Calling(非流式)
|
||||
ai_response = await ai_service.generate_text(
|
||||
prompt=prompts["user"],
|
||||
system_prompt=prompts["system"],
|
||||
tools=openai_tools,
|
||||
tool_choice="required"
|
||||
):
|
||||
# 在 function calling 模式下,chunk 可能是字典格式包含 tool_calls
|
||||
if isinstance(chunk, dict):
|
||||
if "tool_calls" in chunk:
|
||||
tool_calls = chunk["tool_calls"]
|
||||
if "content" in chunk:
|
||||
accumulated_text += chunk.get("content", "")
|
||||
else:
|
||||
accumulated_text += chunk
|
||||
tool_choice="auto"
|
||||
)
|
||||
|
||||
accumulated_text = ai_response.get("content", "")
|
||||
tool_calls = ai_response.get("tool_calls")
|
||||
|
||||
# 5. 检查AI是否返回工具调用
|
||||
if not tool_calls:
|
||||
@@ -214,7 +242,7 @@ class MCPTestService:
|
||||
# 6. 解析工具调用
|
||||
tool_call = tool_calls[0]
|
||||
function = tool_call["function"]
|
||||
tool_name = function["name"]
|
||||
tool_name_with_prefix = function["name"]
|
||||
test_arguments = function["arguments"]
|
||||
|
||||
if isinstance(test_arguments, str):
|
||||
@@ -231,17 +259,23 @@ class MCPTestService:
|
||||
tools_count=len(tools)
|
||||
)
|
||||
|
||||
# 解析插件名和工具名
|
||||
try:
|
||||
_, tool_name = mcp_client.parse_function_name(tool_name_with_prefix)
|
||||
except ValueError:
|
||||
tool_name = tool_name_with_prefix
|
||||
|
||||
logger.info(f"🤖 AI选择的工具: {tool_name}")
|
||||
logger.info(f"📝 AI生成的参数: {test_arguments}")
|
||||
|
||||
# 7. 调用MCP工具
|
||||
# 7. 使用统一门面调用MCP工具
|
||||
call_start = time.time()
|
||||
try:
|
||||
tool_result = await mcp_registry.call_tool(
|
||||
user.user_id,
|
||||
plugin.plugin_name,
|
||||
tool_name,
|
||||
test_arguments
|
||||
tool_result = await mcp_client.call_tool(
|
||||
user_id=user.user_id,
|
||||
plugin_name=plugin.plugin_name,
|
||||
tool_name=tool_name,
|
||||
arguments=test_arguments
|
||||
)
|
||||
|
||||
call_end = time.time()
|
||||
@@ -307,22 +341,6 @@ class MCPTestService:
|
||||
"请检查API Key是否有效"
|
||||
]
|
||||
)
|
||||
|
||||
def _convert_tools_to_openai_format(self, tools: list) -> list:
|
||||
"""将MCP工具格式转换为OpenAI Function Calling格式"""
|
||||
openai_tools = []
|
||||
for tool in tools:
|
||||
openai_tool = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool["name"],
|
||||
"description": tool.get("description", ""),
|
||||
}
|
||||
}
|
||||
if "inputSchema" in tool:
|
||||
openai_tool["function"]["parameters"] = tool["inputSchema"]
|
||||
openai_tools.append(openai_tool)
|
||||
return openai_tools
|
||||
|
||||
|
||||
# 全局单例
|
||||
|
||||
@@ -1,691 +0,0 @@
|
||||
"""MCP工具服务 - 统一管理MCP工具的注入和执行"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass, field
|
||||
from collections import defaultdict
|
||||
|
||||
from app.models.mcp_plugin import MCPPlugin
|
||||
from app.mcp.registry import mcp_registry
|
||||
from app.mcp.config import mcp_config
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolMetrics:
|
||||
"""工具调用指标"""
|
||||
total_calls: int = 0
|
||||
success_calls: int = 0
|
||||
failed_calls: int = 0
|
||||
total_duration_ms: float = 0.0
|
||||
avg_duration_ms: float = 0.0
|
||||
last_call_time: Optional[datetime] = None
|
||||
|
||||
def update_success(self, duration_ms: float):
|
||||
"""更新成功调用指标"""
|
||||
self.total_calls += 1
|
||||
self.success_calls += 1
|
||||
self.total_duration_ms += duration_ms
|
||||
self.avg_duration_ms = self.total_duration_ms / self.total_calls
|
||||
self.last_call_time = datetime.now()
|
||||
|
||||
def update_failure(self, duration_ms: float):
|
||||
"""更新失败调用指标"""
|
||||
self.total_calls += 1
|
||||
self.failed_calls += 1
|
||||
self.total_duration_ms += duration_ms
|
||||
self.avg_duration_ms = self.total_duration_ms / self.total_calls
|
||||
self.last_call_time = datetime.now()
|
||||
|
||||
@property
|
||||
def success_rate(self) -> float:
|
||||
"""成功率"""
|
||||
if self.total_calls == 0:
|
||||
return 0.0
|
||||
return self.success_calls / self.total_calls
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCacheEntry:
|
||||
"""工具缓存条目"""
|
||||
tools: List[Dict[str, Any]]
|
||||
expire_time: datetime
|
||||
hit_count: int = 0
|
||||
|
||||
|
||||
class MCPToolServiceError(Exception):
|
||||
"""MCP工具服务异常"""
|
||||
pass
|
||||
|
||||
|
||||
class MCPToolService:
|
||||
"""MCP工具服务 - 统一管理MCP工具的注入和执行(优化版)"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cache_ttl_minutes: Optional[int] = None,
|
||||
max_retries: Optional[int] = None
|
||||
):
|
||||
"""
|
||||
初始化MCP工具服务
|
||||
|
||||
Args:
|
||||
cache_ttl_minutes: 工具缓存TTL(分钟,默认使用配置)
|
||||
max_retries: 最大重试次数(默认使用配置)
|
||||
"""
|
||||
# 工具定义缓存: {cache_key: ToolCacheEntry}
|
||||
self._tool_cache: Dict[str, ToolCacheEntry] = {}
|
||||
self._cache_ttl = timedelta(
|
||||
minutes=cache_ttl_minutes or mcp_config.TOOL_CACHE_TTL_MINUTES
|
||||
)
|
||||
|
||||
# 调用指标: {tool_key: ToolMetrics}
|
||||
self._metrics: Dict[str, ToolMetrics] = defaultdict(ToolMetrics)
|
||||
|
||||
# 重试配置(使用配置常量)
|
||||
self._max_retries = max_retries or mcp_config.MAX_RETRIES
|
||||
self._base_retry_delay = mcp_config.BASE_RETRY_DELAY_SECONDS
|
||||
self._max_retry_delay = mcp_config.MAX_RETRY_DELAY_SECONDS
|
||||
|
||||
logger.info(
|
||||
f"✅ MCPToolService初始化完成 "
|
||||
f"(缓存TTL={self._cache_ttl.total_seconds()/60:.1f}分钟, "
|
||||
f"最大重试={self._max_retries}次)"
|
||||
)
|
||||
|
||||
async def get_user_enabled_tools(
|
||||
self,
|
||||
user_id: str,
|
||||
db_session: AsyncSession,
|
||||
category: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取用户启用的MCP工具列表
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
db_session: 数据库会话
|
||||
category: 工具类别筛选(search/analysis/filesystem等)
|
||||
|
||||
Returns:
|
||||
工具定义列表,格式符合OpenAI Function Calling规范
|
||||
"""
|
||||
try:
|
||||
# 1. 查询用户启用的插件(enabled=True即可,不强制要求status=active)
|
||||
# 因为新启用的插件status可能还是inactive,需要给它机会被调用
|
||||
query = select(MCPPlugin).where(
|
||||
MCPPlugin.user_id == user_id,
|
||||
MCPPlugin.enabled == True
|
||||
)
|
||||
|
||||
if category:
|
||||
query = query.where(MCPPlugin.category == category)
|
||||
|
||||
result = await db_session.execute(query)
|
||||
plugins = result.scalars().all()
|
||||
|
||||
if not plugins:
|
||||
logger.info(f"用户 {user_id} 没有启用的MCP插件")
|
||||
return []
|
||||
|
||||
# 2. 获取所有工具定义(使用缓存)
|
||||
all_tools = []
|
||||
for plugin in plugins:
|
||||
try:
|
||||
# 确保插件已加载到注册表
|
||||
if not mcp_registry.get_client(user_id, plugin.plugin_name):
|
||||
logger.info(f"插件 {plugin.plugin_name} 未加载,尝试加载...")
|
||||
success = await mcp_registry.load_plugin(plugin)
|
||||
if not success:
|
||||
logger.warning(f"插件 {plugin.plugin_name} 加载失败,跳过")
|
||||
continue
|
||||
|
||||
# ✅ 使用缓存获取工具列表
|
||||
plugin_tools = await self._get_plugin_tools_cached(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin.plugin_name
|
||||
)
|
||||
|
||||
# 格式化为Function Calling格式
|
||||
formatted_tools = self._format_tools_for_ai(
|
||||
plugin_tools,
|
||||
plugin.plugin_name
|
||||
)
|
||||
all_tools.extend(formatted_tools)
|
||||
|
||||
logger.info(
|
||||
f"从插件 {plugin.plugin_name} 加载了 "
|
||||
f"{len(formatted_tools)} 个工具"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"获取插件 {plugin.plugin_name} 的工具失败: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(f"用户 {user_id} 共加载 {len(all_tools)} 个MCP工具")
|
||||
return all_tools
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户MCP工具失败: {e}", exc_info=True)
|
||||
raise MCPToolServiceError(f"获取MCP工具失败: {str(e)}")
|
||||
|
||||
def _format_tools_for_ai(
|
||||
self,
|
||||
plugin_tools: List[Dict[str, Any]],
|
||||
plugin_name: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
将MCP工具定义格式化为AI Function Calling格式
|
||||
|
||||
Args:
|
||||
plugin_tools: MCP插件的工具列表
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
格式化后的工具列表
|
||||
"""
|
||||
formatted_tools = []
|
||||
|
||||
for tool in plugin_tools:
|
||||
formatted_tool = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": f"{plugin_name}_{tool['name']}", # 加插件前缀避免冲突
|
||||
"description": tool.get("description", ""),
|
||||
"parameters": tool.get("inputSchema", {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
})
|
||||
}
|
||||
}
|
||||
formatted_tools.append(formatted_tool)
|
||||
|
||||
return formatted_tools
|
||||
|
||||
async def _get_plugin_tools_cached(
|
||||
self,
|
||||
user_id: str,
|
||||
plugin_name: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
带缓存的工具列表获取
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
工具列表
|
||||
"""
|
||||
cache_key = f"{user_id}:{plugin_name}"
|
||||
now = datetime.now()
|
||||
|
||||
# 检查缓存
|
||||
if cache_key in self._tool_cache:
|
||||
entry = self._tool_cache[cache_key]
|
||||
if now < entry.expire_time:
|
||||
entry.hit_count += 1
|
||||
logger.debug(
|
||||
f"🎯 工具缓存命中: {cache_key} "
|
||||
f"(命中次数: {entry.hit_count})"
|
||||
)
|
||||
return entry.tools
|
||||
else:
|
||||
logger.debug(f"⏰ 工具缓存过期: {cache_key}")
|
||||
del self._tool_cache[cache_key]
|
||||
|
||||
# 缓存未命中,从MCP获取
|
||||
logger.debug(f"🔍 工具缓存未命中,从MCP获取: {cache_key}")
|
||||
tools = await mcp_registry.get_plugin_tools(user_id, plugin_name)
|
||||
|
||||
# 更新缓存
|
||||
self._tool_cache[cache_key] = ToolCacheEntry(
|
||||
tools=tools,
|
||||
expire_time=now + self._cache_ttl,
|
||||
hit_count=0
|
||||
)
|
||||
|
||||
return tools
|
||||
|
||||
def clear_cache(self, user_id: Optional[str] = None, plugin_name: Optional[str] = None):
|
||||
"""
|
||||
清理缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(可选,清理特定用户的缓存)
|
||||
plugin_name: 插件名称(可选,清理特定插件的缓存)
|
||||
"""
|
||||
if user_id is None and plugin_name is None:
|
||||
# 清理所有缓存
|
||||
self._tool_cache.clear()
|
||||
logger.info("🧹 已清理所有工具缓存")
|
||||
elif user_id and plugin_name:
|
||||
# 清理特定插件缓存
|
||||
cache_key = f"{user_id}:{plugin_name}"
|
||||
if cache_key in self._tool_cache:
|
||||
del self._tool_cache[cache_key]
|
||||
logger.info(f"🧹 已清理缓存: {cache_key}")
|
||||
elif user_id:
|
||||
# 清理用户所有缓存
|
||||
keys_to_delete = [
|
||||
key for key in self._tool_cache.keys()
|
||||
if key.startswith(f"{user_id}:")
|
||||
]
|
||||
for key in keys_to_delete:
|
||||
del self._tool_cache[key]
|
||||
logger.info(f"🧹 已清理用户缓存: {user_id} ({len(keys_to_delete)}个)")
|
||||
|
||||
async def execute_tool_calls(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_calls: List[Dict[str, Any]],
|
||||
db_session: AsyncSession,
|
||||
timeout: Optional[float] = None,
|
||||
max_concurrent: int = 2
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
批量执行AI请求的工具调用(限制并发数,避免超时)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
tool_calls: AI返回的工具调用列表
|
||||
db_session: 数据库会话
|
||||
timeout: 单个工具调用的超时时间(秒,默认使用配置)
|
||||
max_concurrent: 最大并发工具调用数(默认2)
|
||||
|
||||
Returns:
|
||||
工具调用结果列表
|
||||
"""
|
||||
if not tool_calls:
|
||||
return []
|
||||
|
||||
# 使用配置的默认超时
|
||||
actual_timeout = timeout or mcp_config.TOOL_CALL_TIMEOUT_SECONDS
|
||||
|
||||
logger.info(f"开始执行 {len(tool_calls)} 个工具调用 (超时={actual_timeout}s, 最大并发={max_concurrent})")
|
||||
|
||||
# ✅ 分批执行,每批最多max_concurrent个
|
||||
all_results = []
|
||||
for i in range(0, len(tool_calls), max_concurrent):
|
||||
batch = tool_calls[i:i+max_concurrent]
|
||||
batch_num = i // max_concurrent + 1
|
||||
total_batches = (len(tool_calls) + max_concurrent - 1) // max_concurrent
|
||||
|
||||
logger.info(f"执行工具批次 {batch_num}/{total_batches}, 数量: {len(batch)}")
|
||||
|
||||
# 创建当前批次的异步任务
|
||||
tasks = [
|
||||
self._execute_single_tool(
|
||||
user_id=user_id,
|
||||
tool_call=tool_call,
|
||||
db_session=db_session,
|
||||
timeout=actual_timeout
|
||||
)
|
||||
for tool_call in batch
|
||||
]
|
||||
|
||||
# 并行执行当前批次
|
||||
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 处理批次结果
|
||||
for j, result in enumerate(batch_results):
|
||||
tool_call = batch[j]
|
||||
|
||||
if isinstance(result, Exception):
|
||||
# 工具调用异常
|
||||
all_results.append({
|
||||
"tool_call_id": tool_call.get("id", f"call_{i+j}"),
|
||||
"role": "tool",
|
||||
"name": tool_call["function"]["name"],
|
||||
"content": f"工具调用失败: {str(result)}",
|
||||
"success": False,
|
||||
"error": str(result)
|
||||
})
|
||||
else:
|
||||
all_results.append(result)
|
||||
|
||||
# 批次间增加短暂延迟,避免API限流
|
||||
if i + max_concurrent < len(tool_calls):
|
||||
await asyncio.sleep(0.5)
|
||||
logger.debug(f"批次间延迟 0.5 秒...")
|
||||
|
||||
return all_results
|
||||
|
||||
async def _execute_single_tool(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_call: Dict[str, Any],
|
||||
db_session: AsyncSession,
|
||||
timeout: float
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
执行单个工具调用
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
tool_call: 工具调用信息
|
||||
db_session: 数据库会话
|
||||
timeout: 超时时间
|
||||
|
||||
Returns:
|
||||
工具调用结果
|
||||
"""
|
||||
tool_call_id = tool_call.get("id", "unknown")
|
||||
function_name = tool_call["function"]["name"]
|
||||
|
||||
try:
|
||||
# 解析插件名和工具名
|
||||
logger.debug(f"🔍 解析工具名称: {function_name}")
|
||||
if "_" in function_name:
|
||||
plugin_name, tool_name = function_name.split("_", 1)
|
||||
logger.debug(f" 插件: {plugin_name}, 工具: {tool_name}")
|
||||
else:
|
||||
raise ValueError(f"无效的工具名称格式: {function_name}")
|
||||
|
||||
# 解析参数
|
||||
arguments_str = tool_call["function"]["arguments"]
|
||||
logger.debug(f"🔍 解析参数:")
|
||||
logger.debug(f" 原始类型: {type(arguments_str)}")
|
||||
logger.debug(f" 原始内容: {arguments_str}")
|
||||
|
||||
if isinstance(arguments_str, str):
|
||||
try:
|
||||
arguments = json.loads(arguments_str)
|
||||
logger.debug(f" ✅ JSON解析成功: {arguments}")
|
||||
except json.JSONDecodeError as je:
|
||||
logger.error(f" ❌ JSON解析失败: {je}")
|
||||
logger.error(f" 原始字符串: '{arguments_str}'")
|
||||
raise ValueError(f"参数JSON解析失败: {je}")
|
||||
else:
|
||||
arguments = arguments_str
|
||||
logger.debug(f" 直接使用dict类型参数")
|
||||
|
||||
logger.info(
|
||||
f"执行工具: {plugin_name}.{tool_name}, "
|
||||
f"参数: {arguments}"
|
||||
)
|
||||
|
||||
# ✅ 使用带重试的调用
|
||||
tool_key = f"{plugin_name}.{tool_name}"
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
result = await self._call_tool_with_retry(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin_name,
|
||||
tool_name=tool_name,
|
||||
arguments=arguments,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
# 记录成功指标
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
self._metrics[tool_key].update_success(duration_ms)
|
||||
|
||||
logger.info(
|
||||
f"✅ 工具调用成功: {tool_key} "
|
||||
f"(耗时: {duration_ms:.2f}ms)"
|
||||
)
|
||||
|
||||
# 成功返回
|
||||
return {
|
||||
"tool_call_id": tool_call_id,
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"content": json.dumps(result, ensure_ascii=False),
|
||||
"success": True,
|
||||
"error": None
|
||||
}
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# 记录失败指标
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
self._metrics[tool_key].update_failure(duration_ms)
|
||||
raise MCPToolServiceError(
|
||||
f"工具调用超时(>{timeout}秒)"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# 记录失败指标
|
||||
tool_key = f"{plugin_name}.{tool_name}" if 'plugin_name' in locals() else function_name
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
self._metrics[tool_key].update_failure(duration_ms)
|
||||
|
||||
logger.error(
|
||||
f"❌ 工具 {function_name} 调用失败: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"tool_call_id": tool_call_id,
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"content": f"工具调用失败: {str(e)}",
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
async def _call_tool_with_retry(
|
||||
self,
|
||||
user_id: str,
|
||||
plugin_name: str,
|
||||
tool_name: str,
|
||||
arguments: Dict[str, Any],
|
||||
timeout: float
|
||||
) -> Any:
|
||||
"""
|
||||
带指数退避重试的工具调用
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
plugin_name: 插件名称
|
||||
tool_name: 工具名称
|
||||
arguments: 工具参数
|
||||
timeout: 超时时间
|
||||
|
||||
Returns:
|
||||
工具执行结果
|
||||
|
||||
Raises:
|
||||
MCPToolServiceError: 工具调用失败
|
||||
asyncio.TimeoutError: 调用超时
|
||||
"""
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(self._max_retries):
|
||||
try:
|
||||
# 尝试调用工具
|
||||
result = await asyncio.wait_for(
|
||||
mcp_registry.call_tool(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin_name,
|
||||
tool_name=tool_name,
|
||||
arguments=arguments
|
||||
),
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
# 成功则返回
|
||||
if attempt > 0:
|
||||
logger.info(
|
||||
f"✅ 重试成功: {plugin_name}.{tool_name} "
|
||||
f"(第{attempt + 1}次尝试)"
|
||||
)
|
||||
return result
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# 超时不重试,直接抛出
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
|
||||
# 最后一次尝试失败
|
||||
if attempt == self._max_retries - 1:
|
||||
logger.error(
|
||||
f"❌ 重试失败: {plugin_name}.{tool_name} "
|
||||
f"(已尝试{self._max_retries}次): {e}"
|
||||
)
|
||||
raise MCPToolServiceError(
|
||||
f"工具调用失败(已重试{self._max_retries}次): {str(e)}"
|
||||
)
|
||||
|
||||
# 计算指数退避延迟
|
||||
delay = min(
|
||||
self._base_retry_delay * (2 ** attempt),
|
||||
self._max_retry_delay
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
f"⚠️ 工具调用失败,{delay:.1f}秒后重试 "
|
||||
f"(第{attempt + 1}/{self._max_retries}次): "
|
||||
f"{plugin_name}.{tool_name} - {e}"
|
||||
)
|
||||
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# 理论上不会到这里,但为了类型安全
|
||||
raise MCPToolServiceError(f"工具调用失败: {last_exception}")
|
||||
|
||||
def get_metrics(self, tool_name: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
获取工具调用指标
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称(可选,获取特定工具的指标)
|
||||
|
||||
Returns:
|
||||
指标字典
|
||||
"""
|
||||
if tool_name:
|
||||
if tool_name in self._metrics:
|
||||
metric = self._metrics[tool_name]
|
||||
return {
|
||||
tool_name: {
|
||||
"total_calls": metric.total_calls,
|
||||
"success_calls": metric.success_calls,
|
||||
"failed_calls": metric.failed_calls,
|
||||
"success_rate": metric.success_rate,
|
||||
"avg_duration_ms": round(metric.avg_duration_ms, 2),
|
||||
"last_call_time": metric.last_call_time.isoformat() if metric.last_call_time else None
|
||||
}
|
||||
}
|
||||
return {}
|
||||
|
||||
# 返回所有工具的指标
|
||||
result = {}
|
||||
for tool_key, metric in self._metrics.items():
|
||||
result[tool_key] = {
|
||||
"total_calls": metric.total_calls,
|
||||
"success_calls": metric.success_calls,
|
||||
"failed_calls": metric.failed_calls,
|
||||
"success_rate": round(metric.success_rate, 3),
|
||||
"avg_duration_ms": round(metric.avg_duration_ms, 2),
|
||||
"last_call_time": metric.last_call_time.isoformat() if metric.last_call_time else None
|
||||
}
|
||||
return result
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
"""获取缓存统计信息"""
|
||||
total_entries = len(self._tool_cache)
|
||||
total_hits = sum(entry.hit_count for entry in self._tool_cache.values())
|
||||
|
||||
return {
|
||||
"total_entries": total_entries,
|
||||
"total_hits": total_hits,
|
||||
"cache_ttl_minutes": self._cache_ttl.total_seconds() / 60,
|
||||
"entries": [
|
||||
{
|
||||
"key": key,
|
||||
"tools_count": len(entry.tools),
|
||||
"hit_count": entry.hit_count,
|
||||
"expire_time": entry.expire_time.isoformat()
|
||||
}
|
||||
for key, entry in self._tool_cache.items()
|
||||
]
|
||||
}
|
||||
|
||||
async def build_tool_context(
|
||||
self,
|
||||
tool_results: List[Dict[str, Any]],
|
||||
format: str = "markdown"
|
||||
) -> str:
|
||||
"""
|
||||
将工具调用结果格式化为上下文文本
|
||||
|
||||
Args:
|
||||
tool_results: 工具调用结果列表
|
||||
format: 输出格式(markdown/json/plain)
|
||||
|
||||
Returns:
|
||||
格式化的上下文字符串
|
||||
"""
|
||||
if not tool_results:
|
||||
return ""
|
||||
|
||||
if format == "markdown":
|
||||
return self._build_markdown_context(tool_results)
|
||||
elif format == "json":
|
||||
return json.dumps(tool_results, ensure_ascii=False, indent=2)
|
||||
else: # plain
|
||||
return self._build_plain_context(tool_results)
|
||||
|
||||
def _build_markdown_context(
|
||||
self,
|
||||
tool_results: List[Dict[str, Any]]
|
||||
) -> str:
|
||||
"""构建Markdown格式的工具上下文"""
|
||||
lines = ["## 🔧 工具调用结果\n"]
|
||||
|
||||
for i, result in enumerate(tool_results, 1):
|
||||
tool_name = result.get("name", "unknown")
|
||||
success = result.get("success", False)
|
||||
content = result.get("content", "")
|
||||
|
||||
status_emoji = "✅" if success else "❌"
|
||||
lines.append(f"### {status_emoji} {i}. {tool_name}\n")
|
||||
|
||||
if success:
|
||||
# 尝试美化JSON内容
|
||||
try:
|
||||
content_obj = json.loads(content)
|
||||
content = json.dumps(content_obj, ensure_ascii=False, indent=2)
|
||||
except:
|
||||
pass
|
||||
lines.append(f"```json\n{content}\n```\n")
|
||||
else:
|
||||
lines.append(f"**错误**: {content}\n")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _build_plain_context(
|
||||
self,
|
||||
tool_results: List[Dict[str, Any]]
|
||||
) -> str:
|
||||
"""构建纯文本格式的工具上下文"""
|
||||
lines = ["=== 工具调用结果 ===\n"]
|
||||
|
||||
for i, result in enumerate(tool_results, 1):
|
||||
tool_name = result.get("name", "unknown")
|
||||
success = result.get("success", False)
|
||||
content = result.get("content", "")
|
||||
|
||||
status = "成功" if success else "失败"
|
||||
lines.append(f"{i}. {tool_name} - {status}")
|
||||
lines.append(f" 结果: {content}\n")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# 全局单例
|
||||
mcp_tool_service = MCPToolService()
|
||||
@@ -0,0 +1,235 @@
|
||||
"""MCP工具加载器 - 统一的工具获取入口
|
||||
|
||||
在AI请求之前,自动检查用户MCP配置并加载可用工具。
|
||||
"""
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.logger import get_logger
|
||||
from app.models.mcp_plugin import MCPPlugin
|
||||
from app.mcp import mcp_client
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserToolsCache:
|
||||
"""用户工具缓存条目"""
|
||||
tools: Optional[List[Dict[str, Any]]]
|
||||
expire_time: datetime
|
||||
hit_count: int = 0
|
||||
|
||||
|
||||
class MCPToolsLoader:
|
||||
"""
|
||||
MCP工具加载器
|
||||
|
||||
负责:
|
||||
1. 检查用户是否配置并启用了MCP插件
|
||||
2. 从各个启用的插件加载工具列表
|
||||
3. 将工具转换为OpenAI Function Calling格式
|
||||
4. 缓存结果以提升性能
|
||||
"""
|
||||
|
||||
_instance: Optional['MCPToolsLoader'] = None
|
||||
|
||||
def __new__(cls):
|
||||
"""单例模式"""
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
# 用户工具缓存: user_id -> UserToolsCache
|
||||
self._cache: Dict[str, UserToolsCache] = {}
|
||||
|
||||
# 缓存TTL(5分钟)
|
||||
self._cache_ttl = timedelta(minutes=5)
|
||||
|
||||
self._initialized = True
|
||||
logger.info("✅ MCPToolsLoader 初始化完成")
|
||||
|
||||
async def has_enabled_plugins(
|
||||
self,
|
||||
user_id: str,
|
||||
db_session: AsyncSession
|
||||
) -> bool:
|
||||
"""
|
||||
检查用户是否有启用的MCP插件
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
db_session: 数据库会话
|
||||
|
||||
Returns:
|
||||
是否有启用的插件
|
||||
"""
|
||||
try:
|
||||
query = select(MCPPlugin.id).where(
|
||||
MCPPlugin.user_id == user_id,
|
||||
MCPPlugin.enabled == True,
|
||||
MCPPlugin.plugin_type.in_(["http", "streamable_http", "sse"])
|
||||
).limit(1)
|
||||
|
||||
result = await db_session.execute(query)
|
||||
return result.scalar() is not None
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"检查用户MCP插件失败: {e}")
|
||||
return False
|
||||
|
||||
async def get_user_tools(
|
||||
self,
|
||||
user_id: str,
|
||||
db_session: AsyncSession,
|
||||
use_cache: bool = True,
|
||||
force_refresh: bool = False
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""
|
||||
获取用户的MCP工具列表(OpenAI格式)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
db_session: 数据库会话
|
||||
use_cache: 是否使用缓存
|
||||
force_refresh: 是否强制刷新
|
||||
|
||||
Returns:
|
||||
- None: 用户未配置或未启用任何MCP插件
|
||||
- []: 有配置但没有可用工具
|
||||
- List[Dict]: OpenAI Function Calling格式的工具列表
|
||||
"""
|
||||
now = datetime.now()
|
||||
|
||||
# 检查缓存
|
||||
if use_cache and not force_refresh and user_id in self._cache:
|
||||
cache_entry = self._cache[user_id]
|
||||
if now < cache_entry.expire_time:
|
||||
cache_entry.hit_count += 1
|
||||
logger.debug(f"🎯 用户工具缓存命中: {user_id} (命中次数: {cache_entry.hit_count})")
|
||||
return cache_entry.tools
|
||||
else:
|
||||
del self._cache[user_id]
|
||||
logger.debug(f"⏰ 用户工具缓存过期: {user_id}")
|
||||
|
||||
# 从数据库加载
|
||||
try:
|
||||
tools = await self._load_user_tools(user_id, db_session)
|
||||
|
||||
# 更新缓存
|
||||
self._cache[user_id] = UserToolsCache(
|
||||
tools=tools,
|
||||
expire_time=now + self._cache_ttl
|
||||
)
|
||||
|
||||
if tools:
|
||||
logger.info(f"🔧 用户 {user_id} 加载了 {len(tools)} 个MCP工具")
|
||||
else:
|
||||
logger.debug(f"📭 用户 {user_id} 没有可用的MCP工具")
|
||||
|
||||
return tools
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 加载用户MCP工具失败: {e}")
|
||||
return None
|
||||
|
||||
async def _load_user_tools(
|
||||
self,
|
||||
user_id: str,
|
||||
db_session: AsyncSession
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""
|
||||
从数据库加载用户启用的MCP插件并获取工具
|
||||
"""
|
||||
# 查询启用的插件
|
||||
query = select(MCPPlugin).where(
|
||||
MCPPlugin.user_id == user_id,
|
||||
MCPPlugin.enabled == True,
|
||||
MCPPlugin.plugin_type.in_(["http", "streamable_http", "sse"])
|
||||
).order_by(MCPPlugin.sort_order)
|
||||
|
||||
result = await db_session.execute(query)
|
||||
plugins = result.scalars().all()
|
||||
|
||||
if not plugins:
|
||||
return None
|
||||
|
||||
all_tools = []
|
||||
|
||||
for plugin in plugins:
|
||||
try:
|
||||
# 确定插件类型
|
||||
plugin_type = plugin.plugin_type
|
||||
if plugin_type == "http":
|
||||
plugin_type = "streamable_http" # 默认使用streamable_http
|
||||
|
||||
# 确保插件已注册到MCP客户端
|
||||
await mcp_client.ensure_registered(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin.plugin_name,
|
||||
url=plugin.server_url,
|
||||
plugin_type=plugin_type,
|
||||
headers=plugin.headers
|
||||
)
|
||||
|
||||
# 获取工具列表
|
||||
plugin_tools = await mcp_client.get_tools(user_id, plugin.plugin_name)
|
||||
|
||||
# 转换为OpenAI格式
|
||||
formatted = mcp_client.format_tools_for_openai(plugin_tools, plugin.plugin_name)
|
||||
all_tools.extend(formatted)
|
||||
|
||||
logger.debug(f"✅ 从插件 {plugin.plugin_name} 加载了 {len(formatted)} 个工具")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 加载插件 {plugin.plugin_name} 工具失败: {e}")
|
||||
continue
|
||||
|
||||
return all_tools if all_tools else None
|
||||
|
||||
def invalidate_cache(self, user_id: Optional[str] = None):
|
||||
"""
|
||||
使缓存失效
|
||||
|
||||
Args:
|
||||
user_id: 用户ID,为None时清空所有缓存
|
||||
"""
|
||||
if user_id:
|
||||
if user_id in self._cache:
|
||||
del self._cache[user_id]
|
||||
logger.debug(f"🧹 清理用户工具缓存: {user_id}")
|
||||
else:
|
||||
count = len(self._cache)
|
||||
self._cache.clear()
|
||||
logger.info(f"🧹 清理所有用户工具缓存 ({count}个)")
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
"""获取缓存统计"""
|
||||
now = datetime.now()
|
||||
return {
|
||||
"total_entries": len(self._cache),
|
||||
"total_hits": sum(e.hit_count for e in self._cache.values()),
|
||||
"cache_ttl_minutes": self._cache_ttl.total_seconds() / 60,
|
||||
"entries": [
|
||||
{
|
||||
"user_id": uid,
|
||||
"tools_count": len(e.tools) if e.tools else 0,
|
||||
"hit_count": e.hit_count,
|
||||
"expired": now >= e.expire_time,
|
||||
"expire_time": e.expire_time.isoformat()
|
||||
}
|
||||
for uid, e in self._cache.items()
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
# 全局单例
|
||||
mcp_tools_loader = MCPToolsLoader()
|
||||
@@ -1,13 +1,231 @@
|
||||
"""Server-Sent Events (SSE) 响应工具类"""
|
||||
import json
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Dict, Any, Optional
|
||||
from enum import Enum
|
||||
from typing import AsyncGenerator, Dict, Any, Optional, Callable
|
||||
from dataclasses import dataclass
|
||||
from fastapi.responses import StreamingResponse
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ProgressStage(Enum):
|
||||
"""标准化进度阶段枚举"""
|
||||
# 初始化阶段 (0-5%)
|
||||
INIT = "init"
|
||||
# 加载数据阶段 (5-15%)
|
||||
LOADING = "loading"
|
||||
# 准备提示词阶段 (15-20%)
|
||||
PREPARING = "preparing"
|
||||
# AI生成阶段 (20-85%)
|
||||
GENERATING = "generating"
|
||||
# 解析数据阶段 (85-92%)
|
||||
PARSING = "parsing"
|
||||
# 保存数据阶段 (92-98%)
|
||||
SAVING = "saving"
|
||||
# 完成阶段 (100%)
|
||||
COMPLETE = "complete"
|
||||
|
||||
|
||||
@dataclass
|
||||
class StageConfig:
|
||||
"""阶段配置"""
|
||||
start: int # 起始进度
|
||||
end: int # 结束进度
|
||||
default_message: str # 默认消息
|
||||
|
||||
|
||||
# 标准进度阶段配置
|
||||
STAGE_CONFIGS: Dict[ProgressStage, StageConfig] = {
|
||||
ProgressStage.INIT: StageConfig(0, 5, "开始处理..."),
|
||||
ProgressStage.LOADING: StageConfig(5, 15, "加载数据中..."),
|
||||
ProgressStage.PREPARING: StageConfig(15, 20, "准备AI提示词..."),
|
||||
ProgressStage.GENERATING: StageConfig(20, 85, "AI生成中..."),
|
||||
ProgressStage.PARSING: StageConfig(85, 92, "解析数据..."),
|
||||
ProgressStage.SAVING: StageConfig(92, 98, "保存到数据库..."),
|
||||
ProgressStage.COMPLETE: StageConfig(100, 100, "完成!"),
|
||||
}
|
||||
|
||||
|
||||
class WizardProgressTracker:
|
||||
"""
|
||||
向导进度追踪器 - 标准化管理SSE进度推送
|
||||
|
||||
使用示例:
|
||||
tracker = WizardProgressTracker("世界观")
|
||||
yield await tracker.start()
|
||||
yield await tracker.loading("加载项目信息")
|
||||
yield await tracker.preparing()
|
||||
async for chunk in ai_stream:
|
||||
yield await tracker.generating_chunk(chunk, len(accumulated))
|
||||
yield await tracker.parsing()
|
||||
yield await tracker.saving("保存世界观数据")
|
||||
yield await tracker.complete()
|
||||
"""
|
||||
|
||||
def __init__(self, task_name: str = "任务"):
|
||||
"""
|
||||
初始化进度追踪器
|
||||
|
||||
Args:
|
||||
task_name: 任务名称,用于消息前缀
|
||||
"""
|
||||
self.task_name = task_name
|
||||
self.current_stage = ProgressStage.INIT
|
||||
self.current_progress = 0
|
||||
self._last_generating_progress = 20 # 生成阶段的最后进度值
|
||||
|
||||
def _get_stage_progress(
|
||||
self,
|
||||
stage: ProgressStage,
|
||||
sub_progress: float = 0.0
|
||||
) -> int:
|
||||
"""
|
||||
计算阶段内的进度值
|
||||
|
||||
Args:
|
||||
stage: 当前阶段
|
||||
sub_progress: 阶段内子进度 (0.0-1.0)
|
||||
|
||||
Returns:
|
||||
总进度值 (0-100)
|
||||
"""
|
||||
config = STAGE_CONFIGS[stage]
|
||||
if sub_progress <= 0:
|
||||
return config.start
|
||||
if sub_progress >= 1:
|
||||
return config.end
|
||||
return config.start + int((config.end - config.start) * sub_progress)
|
||||
|
||||
async def start(self, message: str = None) -> str:
|
||||
"""开始阶段"""
|
||||
self.current_stage = ProgressStage.INIT
|
||||
self.current_progress = 0
|
||||
msg = message or f"开始生成{self.task_name}..."
|
||||
return await SSEResponse.send_progress(msg, 0, "processing")
|
||||
|
||||
async def loading(self, message: str = None, sub_progress: float = 0.5) -> str:
|
||||
"""加载数据阶段"""
|
||||
self.current_stage = ProgressStage.LOADING
|
||||
progress = self._get_stage_progress(ProgressStage.LOADING, sub_progress)
|
||||
self.current_progress = progress
|
||||
msg = message or STAGE_CONFIGS[ProgressStage.LOADING].default_message
|
||||
return await SSEResponse.send_progress(msg, progress, "processing")
|
||||
|
||||
async def preparing(self, message: str = None) -> str:
|
||||
"""准备提示词阶段"""
|
||||
self.current_stage = ProgressStage.PREPARING
|
||||
progress = self._get_stage_progress(ProgressStage.PREPARING, 0.5)
|
||||
self.current_progress = progress
|
||||
msg = message or STAGE_CONFIGS[ProgressStage.PREPARING].default_message
|
||||
return await SSEResponse.send_progress(msg, progress, "processing")
|
||||
|
||||
async def generating(
|
||||
self,
|
||||
current_chars: int = 0,
|
||||
estimated_total: int = 5000,
|
||||
message: str = None,
|
||||
retry_count: int = 0,
|
||||
max_retries: int = 3
|
||||
) -> str:
|
||||
"""
|
||||
AI生成阶段进度更新
|
||||
|
||||
Args:
|
||||
current_chars: 当前已生成字符数
|
||||
estimated_total: 预估总字符数
|
||||
message: 自定义消息
|
||||
retry_count: 当前重试次数
|
||||
max_retries: 最大重试次数
|
||||
"""
|
||||
self.current_stage = ProgressStage.GENERATING
|
||||
|
||||
# 计算生成进度 (0.0-1.0)
|
||||
sub_progress = min(current_chars / max(estimated_total, 1), 1.0)
|
||||
progress = self._get_stage_progress(ProgressStage.GENERATING, sub_progress)
|
||||
|
||||
# 确保进度单调递增
|
||||
if progress < self._last_generating_progress:
|
||||
progress = self._last_generating_progress
|
||||
else:
|
||||
self._last_generating_progress = progress
|
||||
|
||||
self.current_progress = progress
|
||||
|
||||
# 构建消息
|
||||
retry_suffix = f" (重试 {retry_count}/{max_retries})" if retry_count > 0 else ""
|
||||
if message:
|
||||
msg = f"{message}{retry_suffix}"
|
||||
else:
|
||||
msg = f"生成{self.task_name}中... ({current_chars}字符){retry_suffix}"
|
||||
|
||||
return await SSEResponse.send_progress(msg, progress, "processing")
|
||||
|
||||
async def generating_chunk(self, chunk: str) -> str:
|
||||
"""发送生成的内容块"""
|
||||
return await SSEResponse.send_chunk(chunk)
|
||||
|
||||
async def parsing(self, message: str = None, sub_progress: float = 0.5) -> str:
|
||||
"""解析数据阶段"""
|
||||
self.current_stage = ProgressStage.PARSING
|
||||
progress = self._get_stage_progress(ProgressStage.PARSING, sub_progress)
|
||||
self.current_progress = progress
|
||||
msg = message or f"解析{self.task_name}数据..."
|
||||
return await SSEResponse.send_progress(msg, progress, "processing")
|
||||
|
||||
async def saving(self, message: str = None, sub_progress: float = 0.5) -> str:
|
||||
"""保存数据阶段"""
|
||||
self.current_stage = ProgressStage.SAVING
|
||||
progress = self._get_stage_progress(ProgressStage.SAVING, sub_progress)
|
||||
self.current_progress = progress
|
||||
msg = message or f"保存{self.task_name}到数据库..."
|
||||
return await SSEResponse.send_progress(msg, progress, "processing")
|
||||
|
||||
async def complete(self, message: str = None) -> str:
|
||||
"""完成阶段"""
|
||||
self.current_stage = ProgressStage.COMPLETE
|
||||
self.current_progress = 100
|
||||
msg = message or f"{self.task_name}生成完成!"
|
||||
return await SSEResponse.send_progress(msg, 100, "success")
|
||||
|
||||
async def warning(self, message: str) -> str:
|
||||
"""发送警告消息(保持当前进度)"""
|
||||
return await SSEResponse.send_progress(
|
||||
f"⚠️ {message}",
|
||||
self.current_progress,
|
||||
"warning"
|
||||
)
|
||||
|
||||
async def retry(self, retry_count: int, max_retries: int, reason: str = "准备重试") -> str:
|
||||
"""发送重试消息"""
|
||||
return await SSEResponse.send_progress(
|
||||
f"⚠️ {reason}... ({retry_count}/{max_retries})",
|
||||
self.current_progress,
|
||||
"warning"
|
||||
)
|
||||
|
||||
async def error(self, error_message: str, code: int = 500) -> str:
|
||||
"""发送错误消息"""
|
||||
return await SSEResponse.send_error(error_message, code)
|
||||
|
||||
async def result(self, data: Dict[str, Any]) -> str:
|
||||
"""发送结果数据"""
|
||||
return await SSEResponse.send_result(data)
|
||||
|
||||
async def done(self) -> str:
|
||||
"""发送完成信号"""
|
||||
return await SSEResponse.send_done()
|
||||
|
||||
async def heartbeat(self) -> str:
|
||||
"""发送心跳"""
|
||||
return await SSEResponse.send_heartbeat()
|
||||
|
||||
def reset_generating_progress(self):
|
||||
"""重置生成阶段进度(用于重试时)"""
|
||||
self._last_generating_progress = 20
|
||||
|
||||
|
||||
class SSEResponse:
|
||||
"""SSE响应构建器"""
|
||||
|
||||
|
||||
@@ -19,10 +19,11 @@ anthropic==0.72.0
|
||||
|
||||
# 工具库
|
||||
httpx==0.28.1
|
||||
python-dotenv==1.0.0
|
||||
python-dotenv==1.1.0
|
||||
psutil==6.1.1
|
||||
# MCP官方库(Model Context Protocol Python SDK)
|
||||
mcp==1.21.0
|
||||
mcp==1.22.0
|
||||
fastmcp==2.13.3
|
||||
|
||||
|
||||
# NumPy版本锁定(兼容性要求)
|
||||
|
||||
@@ -37,6 +37,14 @@ interface GenerationSteps {
|
||||
outline: GenerationStep;
|
||||
}
|
||||
|
||||
interface WorldBuildingResult {
|
||||
project_id: string;
|
||||
time_period: string;
|
||||
location: string;
|
||||
atmosphere: string;
|
||||
rules: string;
|
||||
}
|
||||
|
||||
export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
||||
config,
|
||||
storagePrefix,
|
||||
@@ -64,7 +72,7 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
||||
// 保存生成数据,用于重试
|
||||
const [generationData, setGenerationData] = useState<GenerationConfig | null>(null);
|
||||
// 保存世界观生成结果,用于后续步骤
|
||||
const [worldBuildingResult, setWorldBuildingResult] = useState<any>(null);
|
||||
const [worldBuildingResult, setWorldBuildingResult] = useState<WorldBuildingResult | null>(null);
|
||||
|
||||
// LocalStorage 键名
|
||||
const storageKeys = {
|
||||
@@ -102,6 +110,7 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
||||
handleAutoGenerate(config);
|
||||
}
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [config, resumeProjectId]);
|
||||
|
||||
// 恢复未完成项目的生成
|
||||
@@ -125,33 +134,40 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
||||
const wizardStep = project.wizard_step || 0;
|
||||
|
||||
// 根据wizard_step判断从哪里继续
|
||||
// wizard_step: 0=未开始, 1=世界观已完成, 2=职业体系已完成, 3=角色已完成, 4=大纲已完成
|
||||
// 获取世界观数据(用于后续步骤)
|
||||
const worldResult = {
|
||||
project_id: projectIdParam,
|
||||
time_period: project.world_time_period || '',
|
||||
location: project.world_location || '',
|
||||
atmosphere: project.world_atmosphere || '',
|
||||
rules: project.world_rules || ''
|
||||
};
|
||||
|
||||
if (wizardStep === 0) {
|
||||
// 从世界观开始
|
||||
message.info('从世界观步骤开始生成...');
|
||||
setGenerationSteps({ worldBuilding: 'processing', careers: 'pending', characters: 'pending', outline: 'pending' });
|
||||
await resumeFromWorldBuilding(data);
|
||||
} else if (wizardStep === 1) {
|
||||
// 世界观已完成,从角色开始
|
||||
message.info('世界观已完成,从角色步骤继续...');
|
||||
setGenerationSteps({ worldBuilding: 'completed', careers: 'completed', characters: 'processing', outline: 'pending' });
|
||||
|
||||
// 获取世界观数据
|
||||
const worldResult = {
|
||||
project_id: projectIdParam,
|
||||
time_period: project.world_time_period || '',
|
||||
location: project.world_location || '',
|
||||
atmosphere: project.world_atmosphere || '',
|
||||
rules: project.world_rules || ''
|
||||
};
|
||||
// 世界观已完成,从职业体系开始
|
||||
message.info('世界观已完成,从职业体系步骤继续...');
|
||||
setGenerationSteps({ worldBuilding: 'completed', careers: 'processing', characters: 'pending', outline: 'pending' });
|
||||
setWorldBuildingResult(worldResult);
|
||||
setProgress(33);
|
||||
|
||||
await resumeFromCharacters(data, worldResult);
|
||||
setProgress(20);
|
||||
await resumeFromCareers(data, worldResult);
|
||||
} else if (wizardStep === 2) {
|
||||
// 世界观和角色已完成,从大纲开始
|
||||
message.info('世界观和角色已完成,从大纲步骤继续...');
|
||||
// 职业体系已完成,从角色开始
|
||||
message.info('职业体系已完成,从角色步骤继续...');
|
||||
setGenerationSteps({ worldBuilding: 'completed', careers: 'completed', characters: 'processing', outline: 'pending' });
|
||||
setWorldBuildingResult(worldResult);
|
||||
setProgress(40);
|
||||
await resumeFromCharacters(data, worldResult);
|
||||
} else if (wizardStep === 3) {
|
||||
// 角色已完成,从大纲开始
|
||||
message.info('角色已完成,从大纲步骤继续...');
|
||||
setGenerationSteps({ worldBuilding: 'completed', careers: 'completed', characters: 'completed', outline: 'processing' });
|
||||
setProgress(66);
|
||||
setProgress(70);
|
||||
await resumeFromOutline(data, projectIdParam);
|
||||
} else {
|
||||
// 已全部完成
|
||||
@@ -211,11 +227,47 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
||||
}
|
||||
);
|
||||
|
||||
await resumeFromCareers(data, worldResult);
|
||||
};
|
||||
|
||||
// 恢复:从职业体系步骤继续
|
||||
const resumeFromCareers = async (data: GenerationConfig, worldResult: WorldBuildingResult) => {
|
||||
const pid = projectId || worldResult.project_id;
|
||||
|
||||
setGenerationSteps(prev => ({ ...prev, careers: 'processing' }));
|
||||
setProgressMessage('正在生成职业体系...');
|
||||
|
||||
await wizardStreamApi.generateCareerSystemStream(
|
||||
{
|
||||
project_id: pid,
|
||||
},
|
||||
{
|
||||
onProgress: (msg, prog) => {
|
||||
setProgress(prog);
|
||||
setProgressMessage(msg);
|
||||
},
|
||||
onResult: (result) => {
|
||||
console.log(`成功生成职业体系:主职业${result.main_careers_count}个,副职业${result.sub_careers_count}个`);
|
||||
setGenerationSteps(prev => ({ ...prev, careers: 'completed' }));
|
||||
},
|
||||
onError: (error) => {
|
||||
console.error('职业体系生成失败:', error);
|
||||
setErrorDetails(`职业体系生成失败: ${error}`);
|
||||
setGenerationSteps(prev => ({ ...prev, careers: 'error' }));
|
||||
setLoading(false);
|
||||
throw new Error(error);
|
||||
},
|
||||
onComplete: () => {
|
||||
console.log('职业体系生成完成');
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
await resumeFromCharacters(data, worldResult);
|
||||
};
|
||||
|
||||
// 恢复:从角色步骤继续
|
||||
const resumeFromCharacters = async (data: GenerationConfig, worldResult: any) => {
|
||||
const resumeFromCharacters = async (data: GenerationConfig, worldResult: WorldBuildingResult) => {
|
||||
const genreString = Array.isArray(data.genre) ? data.genre.join('、') : data.genre;
|
||||
const pid = projectId || worldResult.project_id;
|
||||
|
||||
@@ -342,26 +394,11 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
||||
// 直接使用后端返回的进度值
|
||||
setProgress(prog);
|
||||
setProgressMessage(msg);
|
||||
|
||||
// 检测职业体系生成阶段
|
||||
if (msg.includes('职业体系')) {
|
||||
if (msg.includes('开始') || msg.includes('生成')) {
|
||||
setGenerationSteps(prev => ({
|
||||
...prev,
|
||||
worldBuilding: 'completed',
|
||||
careers: 'processing'
|
||||
}));
|
||||
}
|
||||
if (msg.includes('完成') || msg.includes('✅')) {
|
||||
setGenerationSteps(prev => ({ ...prev, careers: 'completed' }));
|
||||
}
|
||||
}
|
||||
},
|
||||
onResult: (result) => {
|
||||
setProjectId(result.project_id);
|
||||
setWorldBuildingResult(result);
|
||||
setGenerationSteps(prev => ({ ...prev, worldBuilding: 'completed' }));
|
||||
// 职业体系状态已在onProgress中更新
|
||||
},
|
||||
onError: (error) => {
|
||||
console.error('世界观生成失败:', error);
|
||||
@@ -385,7 +422,37 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
||||
setWorldBuildingResult(worldResult);
|
||||
saveProgress(createdProjectId, data, 'generating');
|
||||
|
||||
// 步骤2: 生成角色
|
||||
// 步骤2: 生成职业体系
|
||||
setGenerationSteps(prev => ({ ...prev, careers: 'processing' }));
|
||||
setProgressMessage('正在生成职业体系...');
|
||||
|
||||
await wizardStreamApi.generateCareerSystemStream(
|
||||
{
|
||||
project_id: createdProjectId,
|
||||
},
|
||||
{
|
||||
onProgress: (msg, prog) => {
|
||||
setProgress(prog);
|
||||
setProgressMessage(msg);
|
||||
},
|
||||
onResult: (result) => {
|
||||
console.log(`成功生成职业体系:主职业${result.main_careers_count}个,副职业${result.sub_careers_count}个`);
|
||||
setGenerationSteps(prev => ({ ...prev, careers: 'completed' }));
|
||||
},
|
||||
onError: (error) => {
|
||||
console.error('职业体系生成失败:', error);
|
||||
setErrorDetails(`职业体系生成失败: ${error}`);
|
||||
setGenerationSteps(prev => ({ ...prev, careers: 'error' }));
|
||||
setLoading(false);
|
||||
throw new Error(error);
|
||||
},
|
||||
onComplete: () => {
|
||||
console.log('职业体系生成完成');
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
// 步骤3: 生成角色
|
||||
setGenerationSteps(prev => ({ ...prev, characters: 'processing' }));
|
||||
setProgressMessage('正在生成角色...');
|
||||
|
||||
@@ -497,6 +564,9 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
||||
if (generationSteps.worldBuilding === 'error') {
|
||||
message.info('从世界观步骤开始重新生成...');
|
||||
await retryFromWorldBuilding();
|
||||
} else if (generationSteps.careers === 'error') {
|
||||
message.info('从职业体系步骤继续生成...');
|
||||
await retryFromCareers();
|
||||
} else if (generationSteps.characters === 'error') {
|
||||
message.info('从角色步骤继续生成...');
|
||||
await retryFromCharacters();
|
||||
@@ -504,9 +574,10 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
||||
message.info('从大纲步骤继续生成...');
|
||||
await retryFromOutline();
|
||||
}
|
||||
} catch (error: any) {
|
||||
} catch (error) {
|
||||
console.error('智能重试失败:', error);
|
||||
message.error('重试失败:' + (error.message || '未知错误'));
|
||||
const errorMessage = error instanceof Error ? error.message : '未知错误';
|
||||
message.error('重试失败:' + errorMessage);
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
@@ -537,20 +608,6 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
||||
// 直接使用后端返回的进度值
|
||||
setProgress(prog);
|
||||
setProgressMessage(msg);
|
||||
|
||||
// 检测职业体系生成阶段
|
||||
if (msg.includes('职业体系')) {
|
||||
if (msg.includes('开始') || msg.includes('生成')) {
|
||||
setGenerationSteps(prev => ({
|
||||
...prev,
|
||||
worldBuilding: 'completed',
|
||||
careers: 'processing'
|
||||
}));
|
||||
}
|
||||
if (msg.includes('完成') || msg.includes('✅')) {
|
||||
setGenerationSteps(prev => ({ ...prev, careers: 'completed' }));
|
||||
}
|
||||
}
|
||||
},
|
||||
onResult: (result) => {
|
||||
setProjectId(result.project_id);
|
||||
@@ -574,17 +631,72 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
||||
throw new Error('项目创建失败:未获取到项目ID');
|
||||
}
|
||||
|
||||
await continueFromCharacters(worldResult);
|
||||
await continueFromCareers(worldResult);
|
||||
};
|
||||
|
||||
// 从职业体系步骤继续
|
||||
const retryFromCareers = async () => {
|
||||
if (!worldBuildingResult) {
|
||||
message.warning('缺少必要数据,无法从职业体系步骤继续');
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const pid = worldBuildingResult.project_id || projectId;
|
||||
if (!pid) {
|
||||
message.warning('缺少项目ID,无法从职业体系步骤继续');
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
setGenerationSteps(prev => ({ ...prev, careers: 'processing' }));
|
||||
setProgressMessage('重新生成职业体系...');
|
||||
|
||||
await wizardStreamApi.generateCareerSystemStream(
|
||||
{
|
||||
project_id: pid,
|
||||
},
|
||||
{
|
||||
onProgress: (msg, prog) => {
|
||||
setProgress(prog);
|
||||
setProgressMessage(msg);
|
||||
},
|
||||
onResult: (result) => {
|
||||
console.log(`成功生成职业体系:主职业${result.main_careers_count}个,副职业${result.sub_careers_count}个`);
|
||||
setGenerationSteps(prev => ({ ...prev, careers: 'completed' }));
|
||||
},
|
||||
onError: (error) => {
|
||||
console.error('职业体系生成失败:', error);
|
||||
setErrorDetails(`职业体系生成失败: ${error}`);
|
||||
setGenerationSteps(prev => ({ ...prev, careers: 'error' }));
|
||||
setLoading(false);
|
||||
throw new Error(error);
|
||||
},
|
||||
onComplete: () => {
|
||||
console.log('职业体系重新生成完成');
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
await continueFromCharacters(worldBuildingResult);
|
||||
};
|
||||
|
||||
// 从角色步骤继续
|
||||
const retryFromCharacters = async () => {
|
||||
if (!generationData || !projectId || !worldBuildingResult) {
|
||||
if (!generationData || !worldBuildingResult) {
|
||||
message.warning('缺少必要数据,无法从角色步骤继续');
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// 优先使用 worldBuildingResult 中的 project_id,因为重试可能创建了新项目
|
||||
const pid = worldBuildingResult.project_id || projectId;
|
||||
if (!pid) {
|
||||
message.warning('缺少项目ID,无法从角色步骤继续');
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
setGenerationSteps(prev => ({ ...prev, characters: 'processing' }));
|
||||
setProgressMessage('重新生成角色...');
|
||||
|
||||
@@ -592,7 +704,7 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
||||
|
||||
await wizardStreamApi.generateCharactersStream(
|
||||
{
|
||||
project_id: projectId,
|
||||
project_id: pid,
|
||||
count: generationData.character_count,
|
||||
world_context: {
|
||||
time_period: worldBuildingResult.time_period || '',
|
||||
@@ -626,23 +738,31 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
||||
}
|
||||
);
|
||||
|
||||
await continueFromOutline();
|
||||
await continueFromOutline(pid);
|
||||
};
|
||||
|
||||
// 从大纲步骤继续
|
||||
const retryFromOutline = async () => {
|
||||
if (!generationData || !projectId) {
|
||||
if (!generationData) {
|
||||
message.warning('缺少必要数据,无法从大纲步骤继续');
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// 优先使用 worldBuildingResult 中的 project_id,fallback 到状态中的 projectId
|
||||
const pid = (worldBuildingResult?.project_id) || projectId;
|
||||
if (!pid) {
|
||||
message.warning('缺少项目ID,无法从大纲步骤继续');
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
setGenerationSteps(prev => ({ ...prev, outline: 'processing' }));
|
||||
setProgressMessage('重新生成大纲...');
|
||||
|
||||
await wizardStreamApi.generateCompleteOutlineStream(
|
||||
{
|
||||
project_id: projectId,
|
||||
project_id: pid,
|
||||
chapter_count: generationData.chapter_count,
|
||||
narrative_perspective: generationData.narrative_perspective,
|
||||
target_words: generationData.target_words,
|
||||
@@ -676,20 +796,59 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
||||
setLoading(false);
|
||||
|
||||
// 调用完成回调
|
||||
if (projectId) {
|
||||
onComplete(projectId);
|
||||
if (pid) {
|
||||
onComplete(pid);
|
||||
|
||||
// 延迟1秒后自动跳转到项目详情页
|
||||
setTimeout(() => {
|
||||
navigate(`/project/${projectId}`);
|
||||
navigate(`/project/${pid}`);
|
||||
}, 1000);
|
||||
}
|
||||
};
|
||||
|
||||
// 从角色步骤开始的完整流程
|
||||
const continueFromCharacters = async (worldResult: any) => {
|
||||
// 从职业体系步骤开始的完整流程
|
||||
const continueFromCareers = async (worldResult: WorldBuildingResult) => {
|
||||
if (!generationData || !worldResult?.project_id) return;
|
||||
|
||||
const pid = worldResult.project_id;
|
||||
|
||||
setGenerationSteps(prev => ({ ...prev, careers: 'processing' }));
|
||||
setProgressMessage('正在生成职业体系...');
|
||||
|
||||
await wizardStreamApi.generateCareerSystemStream(
|
||||
{
|
||||
project_id: pid,
|
||||
},
|
||||
{
|
||||
onProgress: (msg, prog) => {
|
||||
setProgress(prog);
|
||||
setProgressMessage(msg);
|
||||
},
|
||||
onResult: (result) => {
|
||||
console.log(`成功生成职业体系:主职业${result.main_careers_count}个,副职业${result.sub_careers_count}个`);
|
||||
setGenerationSteps(prev => ({ ...prev, careers: 'completed' }));
|
||||
},
|
||||
onError: (error) => {
|
||||
console.error('职业体系生成失败:', error);
|
||||
setErrorDetails(`职业体系生成失败: ${error}`);
|
||||
setGenerationSteps(prev => ({ ...prev, careers: 'error' }));
|
||||
setLoading(false);
|
||||
throw new Error(error);
|
||||
},
|
||||
onComplete: () => {
|
||||
console.log('职业体系生成完成');
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
await continueFromCharacters(worldResult);
|
||||
};
|
||||
|
||||
// 从角色步骤开始的完整流程
|
||||
const continueFromCharacters = async (worldResult: WorldBuildingResult) => {
|
||||
if (!generationData || !worldResult?.project_id) return;
|
||||
|
||||
const pid = worldResult.project_id;
|
||||
const genreString = Array.isArray(generationData.genre) ? generationData.genre.join('、') : generationData.genre;
|
||||
|
||||
setGenerationSteps(prev => ({ ...prev, characters: 'processing' }));
|
||||
@@ -697,7 +856,7 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
||||
|
||||
await wizardStreamApi.generateCharactersStream(
|
||||
{
|
||||
project_id: worldResult.project_id,
|
||||
project_id: pid,
|
||||
count: generationData.character_count,
|
||||
world_context: {
|
||||
time_period: worldResult.time_period || '',
|
||||
@@ -731,19 +890,19 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
||||
}
|
||||
);
|
||||
|
||||
await continueFromOutline();
|
||||
await continueFromOutline(pid);
|
||||
};
|
||||
|
||||
// 从大纲步骤开始的完整流程
|
||||
const continueFromOutline = async () => {
|
||||
if (!generationData || !projectId) return;
|
||||
const continueFromOutline = async (pid: string) => {
|
||||
if (!generationData || !pid) return;
|
||||
|
||||
setGenerationSteps(prev => ({ ...prev, outline: 'processing' }));
|
||||
setProgressMessage('正在生成大纲...');
|
||||
|
||||
await wizardStreamApi.generateCompleteOutlineStream(
|
||||
{
|
||||
project_id: projectId,
|
||||
project_id: pid,
|
||||
chapter_count: generationData.chapter_count,
|
||||
narrative_perspective: generationData.narrative_perspective,
|
||||
target_words: generationData.target_words,
|
||||
@@ -777,12 +936,12 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
||||
setLoading(false);
|
||||
|
||||
// 调用完成回调
|
||||
if (projectId) {
|
||||
onComplete(projectId);
|
||||
if (pid) {
|
||||
onComplete(pid);
|
||||
|
||||
// 延迟1秒后自动跳转到项目详情页
|
||||
setTimeout(() => {
|
||||
navigate(`/project/${projectId}`);
|
||||
navigate(`/project/${pid}`);
|
||||
}, 1000);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -28,8 +28,11 @@ import {
|
||||
InfoCircleOutlined,
|
||||
ToolOutlined,
|
||||
ArrowLeftOutlined,
|
||||
ApiOutlined,
|
||||
QuestionCircleOutlined,
|
||||
WarningOutlined,
|
||||
} from '@ant-design/icons';
|
||||
import { mcpPluginApi } from '../services/api';
|
||||
import { mcpPluginApi, settingsApi } from '../services/api';
|
||||
import type { MCPPlugin, MCPTool } from '../types';
|
||||
|
||||
const { Paragraph, Text, Title } = Typography;
|
||||
@@ -46,24 +49,112 @@ export default function MCPPluginsPage() {
|
||||
const [editingPlugin, setEditingPlugin] = useState<MCPPlugin | null>(null);
|
||||
const [testingPluginId, setTestingPluginId] = useState<string | null>(null);
|
||||
const [viewingTools, setViewingTools] = useState<{ pluginId: string; tools: MCPTool[] } | null>(null);
|
||||
const [checkingFunctionCalling, setCheckingFunctionCalling] = useState(false);
|
||||
const [modelSupportStatus, setModelSupportStatus] = useState<'unknown' | 'supported' | 'unsupported'>('unknown');
|
||||
|
||||
useEffect(() => {
|
||||
loadPlugins();
|
||||
}, []);
|
||||
const initPage = async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
// 1. 并行获取插件列表和当前设置
|
||||
const [pluginsData, settings] = await Promise.all([
|
||||
mcpPluginApi.getPlugins(),
|
||||
settingsApi.getSettings()
|
||||
]);
|
||||
|
||||
setPlugins(pluginsData);
|
||||
|
||||
// 2. 检查配置一致性
|
||||
const verifiedConfigStr = localStorage.getItem('mcp_verified_config');
|
||||
if (verifiedConfigStr) {
|
||||
try {
|
||||
const verifiedConfig = JSON.parse(verifiedConfigStr);
|
||||
const currentConfig = {
|
||||
provider: settings.api_provider,
|
||||
baseUrl: settings.api_base_url,
|
||||
model: settings.llm_model
|
||||
};
|
||||
|
||||
// 比较关键配置是否发生变更
|
||||
const isConfigChanged =
|
||||
verifiedConfig.provider !== currentConfig.provider ||
|
||||
verifiedConfig.baseUrl !== currentConfig.baseUrl ||
|
||||
verifiedConfig.model !== currentConfig.model;
|
||||
|
||||
if (isConfigChanged) {
|
||||
// 配置已变更
|
||||
setModelSupportStatus('unknown');
|
||||
|
||||
// 检查是否有正在运行的插件
|
||||
const activePlugins = pluginsData.filter(p => p.enabled);
|
||||
if (activePlugins.length > 0) {
|
||||
// 自动禁用所有插件
|
||||
message.loading({ content: '检测到模型配置变更,正在为了安全自动禁用插件...', key: 'auto_disable' });
|
||||
|
||||
await Promise.all(activePlugins.map(p => mcpPluginApi.togglePlugin(p.id, false)));
|
||||
|
||||
// 重新加载插件列表状态
|
||||
const updatedPlugins = await mcpPluginApi.getPlugins();
|
||||
setPlugins(updatedPlugins);
|
||||
|
||||
message.success({ content: '已自动禁用所有插件,请重新检测模型能力', key: 'auto_disable' });
|
||||
|
||||
modal.warning({
|
||||
title: '配置变更提醒',
|
||||
centered: true,
|
||||
content: '检测到您更换了 AI 模型或接口地址。为了防止错误调用,系统已自动暂停所有 MCP 插件。请重新进行"模型能力检查",确认新模型支持 Function Calling 后再启用插件。',
|
||||
okText: '知道了',
|
||||
});
|
||||
} else {
|
||||
// 没有运行中的插件,仅提示
|
||||
message.info('检测到模型配置已变更,请重新检测模型能力');
|
||||
}
|
||||
|
||||
// 清除旧的验证状态
|
||||
localStorage.removeItem('mcp_verified_config');
|
||||
} else {
|
||||
// 配置未变更,恢复验证状态(根据缓存的状态恢复)
|
||||
const cachedStatus = verifiedConfig.status || 'supported';
|
||||
setModelSupportStatus(cachedStatus as 'unknown' | 'supported' | 'unsupported');
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Failed to parse verified config:', e);
|
||||
localStorage.removeItem('mcp_verified_config');
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Init page failed:', error);
|
||||
message.error('页面初始化失败');
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
initPage();
|
||||
}, [modal]);
|
||||
|
||||
const loadPlugins = async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const data = await mcpPluginApi.getPlugins();
|
||||
setPlugins(data);
|
||||
} catch (error) {
|
||||
console.error('Load plugins failed:', error);
|
||||
message.error('加载插件列表失败');
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleCreate = () => {
|
||||
if (modelSupportStatus !== 'supported') {
|
||||
modal.confirm({
|
||||
title: '模型能力检查',
|
||||
centered: true,
|
||||
icon: <WarningOutlined />,
|
||||
content: '为了确保 MCP 插件正常工作,您当前使用的 AI 模型必须支持 Function Calling(工具调用)能力。请先进行模型支持检测。',
|
||||
okText: '去检测',
|
||||
cancelText: '取消',
|
||||
onOk: handleCheckFunctionCalling,
|
||||
});
|
||||
return;
|
||||
}
|
||||
setEditingPlugin(null);
|
||||
form.resetFields();
|
||||
form.setFieldsValue({
|
||||
@@ -86,7 +177,7 @@ export default function MCPPluginsPage() {
|
||||
setEditingPlugin(plugin);
|
||||
|
||||
// 重构为标准MCP配置格式
|
||||
const mcpConfig: any = {
|
||||
const mcpConfig: Record<string, Record<string, Record<string, unknown>>> = {
|
||||
mcpServers: {
|
||||
[plugin.plugin_name]: {
|
||||
type: plugin.plugin_type || 'http'
|
||||
@@ -94,7 +185,7 @@ export default function MCPPluginsPage() {
|
||||
}
|
||||
};
|
||||
|
||||
if (plugin.plugin_type === 'http') {
|
||||
if (plugin.plugin_type === 'http' || plugin.plugin_type === 'streamable_http' || plugin.plugin_type === 'sse') {
|
||||
mcpConfig.mcpServers[plugin.plugin_name].url = plugin.server_url;
|
||||
mcpConfig.mcpServers[plugin.plugin_name].headers = plugin.headers || {};
|
||||
} else {
|
||||
@@ -125,6 +216,7 @@ export default function MCPPluginsPage() {
|
||||
message.success('插件已删除');
|
||||
loadPlugins();
|
||||
} catch (error) {
|
||||
console.error('Delete plugin failed:', error);
|
||||
message.error('删除插件失败');
|
||||
}
|
||||
},
|
||||
@@ -137,6 +229,7 @@ export default function MCPPluginsPage() {
|
||||
message.success(enabled ? '插件已启用' : '插件已禁用');
|
||||
loadPlugins();
|
||||
} catch (error) {
|
||||
console.error('Toggle plugin failed:', error);
|
||||
message.error('切换插件状态失败');
|
||||
}
|
||||
};
|
||||
@@ -150,45 +243,62 @@ export default function MCPPluginsPage() {
|
||||
await loadPlugins();
|
||||
|
||||
if (result.success) {
|
||||
const suggestions = result.suggestions || [];
|
||||
const aiChoice = suggestions.find((s: string) => s.startsWith('🤖'))?.replace('🤖 AI选择: ', '') || '';
|
||||
const paramsStr = suggestions.find((s: string) => s.startsWith('📝'))?.replace('📝 参数: ', '') || '';
|
||||
const callTime = suggestions.find((s: string) => s.startsWith('⏱️'))?.replace('⏱️ 耗时: ', '') || '';
|
||||
const resultStr = suggestions.find((s: string) => s.startsWith('📊'))?.replace('📊 结果:\n', '') || '';
|
||||
|
||||
modal.success({
|
||||
title: '测试成功',
|
||||
title: '🎉 测试成功',
|
||||
centered: true,
|
||||
width: isMobile ? '90%' : 600,
|
||||
width: isMobile ? '95%' : 700,
|
||||
content: (
|
||||
<div style={{ padding: '8px 0' }}>
|
||||
<div style={{ marginBottom: 24, padding: 16, background: 'var(--color-success-bg)', border: '1px solid var(--color-success-border)', borderRadius: 8 }}>
|
||||
<Typography.Text strong style={{ color: 'var(--color-success)' }}>
|
||||
<div style={{ marginBottom: 16, padding: 12, background: 'var(--color-success-bg)', border: '1px solid var(--color-success-border)', borderRadius: 8 }}>
|
||||
<Typography.Text strong style={{ color: 'var(--color-success)', fontSize: 14 }}>
|
||||
✓ {result.message}
|
||||
</Typography.Text>
|
||||
</div>
|
||||
|
||||
{(result.tools_count !== undefined || result.response_time_ms !== undefined) && (
|
||||
<div style={{
|
||||
padding: 16,
|
||||
background: 'var(--color-bg-layout)',
|
||||
borderRadius: 8,
|
||||
marginBottom: 16
|
||||
}}>
|
||||
{result.tools_count !== undefined && (
|
||||
<div style={{ marginBottom: 8, fontSize: 14 }}>
|
||||
<Text type="secondary">可用工具数:</Text>
|
||||
<Text strong>{result.tools_count}</Text>
|
||||
</div>
|
||||
)}
|
||||
{result.response_time_ms !== undefined && (
|
||||
<div style={{ fontSize: 14 }}>
|
||||
<Text type="secondary">响应时间:</Text>
|
||||
<Text strong>{result.response_time_ms}ms</Text>
|
||||
</div>
|
||||
)}
|
||||
<div style={{ display: 'grid', gridTemplateColumns: isMobile ? '1fr' : '1fr 1fr', gap: 12, marginBottom: 16 }}>
|
||||
<div style={{ padding: 12, background: 'var(--color-bg-layout)', borderRadius: 8 }}>
|
||||
<Text type="secondary" style={{ fontSize: 12 }}>可用工具数</Text>
|
||||
<div><Text strong style={{ fontSize: 20 }}>{result.tools_count || 0}</Text></div>
|
||||
</div>
|
||||
<div style={{ padding: 12, background: 'var(--color-bg-layout)', borderRadius: 8 }}>
|
||||
<Text type="secondary" style={{ fontSize: 12 }}>总响应时间</Text>
|
||||
<div><Text strong style={{ fontSize: 20 }}>{result.response_time_ms?.toFixed(0) || 0}ms</Text></div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{aiChoice && (
|
||||
<div style={{ marginBottom: 12, padding: 12, background: 'var(--color-info-bg)', borderRadius: 8, border: '1px solid var(--color-info-border)' }}>
|
||||
<Text type="secondary" style={{ fontSize: 12, display: 'block', marginBottom: 4 }}>🤖 AI选择的工具</Text>
|
||||
<Text code strong>{aiChoice}</Text>
|
||||
{callTime && <Tag color="blue" style={{ marginLeft: 8 }}>{callTime}</Tag>}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Alert
|
||||
message='插件状态已自动更新为"运行中"'
|
||||
type="success"
|
||||
showIcon
|
||||
/>
|
||||
{paramsStr && (
|
||||
<div style={{ marginBottom: 12 }}>
|
||||
<Text type="secondary" style={{ fontSize: 12, display: 'block', marginBottom: 4 }}>📝 调用参数</Text>
|
||||
<pre style={{ margin: 0, padding: 8, background: 'var(--color-bg-layout)', borderRadius: 4, fontSize: 12, overflow: 'auto', maxHeight: 100 }}>
|
||||
{(() => { try { return JSON.stringify(JSON.parse(paramsStr), null, 2); } catch { return paramsStr; } })()}
|
||||
</pre>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{resultStr && (
|
||||
<div style={{ marginBottom: 12 }}>
|
||||
<Text type="secondary" style={{ fontSize: 12, display: 'block', marginBottom: 4 }}>📊 返回结果预览</Text>
|
||||
<pre style={{ margin: 0, padding: 8, background: 'var(--color-bg-layout)', borderRadius: 4, fontSize: 11, overflow: 'auto', maxHeight: 150, whiteSpace: 'pre-wrap', wordBreak: 'break-word' }}>
|
||||
{resultStr}
|
||||
</pre>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Alert message='插件状态已自动更新为"运行中"' type="success" showIcon />
|
||||
</div>
|
||||
),
|
||||
});
|
||||
@@ -248,7 +358,7 @@ export default function MCPPluginsPage() {
|
||||
),
|
||||
});
|
||||
}
|
||||
} catch (error: any) {
|
||||
} catch {
|
||||
message.error('测试插件失败');
|
||||
} finally {
|
||||
setTestingPluginId(null);
|
||||
@@ -260,17 +370,181 @@ export default function MCPPluginsPage() {
|
||||
const result = await mcpPluginApi.getPluginTools(pluginId);
|
||||
setViewingTools({ pluginId, tools: result.tools });
|
||||
} catch (error) {
|
||||
console.error('Get tools failed:', error);
|
||||
message.error('获取工具列表失败');
|
||||
}
|
||||
};
|
||||
|
||||
const handleSubmit = async (values: any) => {
|
||||
const handleCheckFunctionCalling = async () => {
|
||||
// 从设置中获取当前配置
|
||||
setCheckingFunctionCalling(true);
|
||||
try {
|
||||
const settings = await settingsApi.getSettings();
|
||||
|
||||
if (!settings.api_key || !settings.llm_model) {
|
||||
message.warning('请先在设置页面配置 API Key 和模型');
|
||||
return;
|
||||
}
|
||||
|
||||
const result = await settingsApi.checkFunctionCalling({
|
||||
api_key: settings.api_key,
|
||||
api_base_url: settings.api_base_url || '',
|
||||
provider: settings.api_provider || 'openai',
|
||||
llm_model: settings.llm_model,
|
||||
});
|
||||
|
||||
// 无论成功失败,都缓存当前测试的配置和状态
|
||||
const configToCache = {
|
||||
provider: settings.api_provider,
|
||||
baseUrl: settings.api_base_url,
|
||||
model: settings.llm_model,
|
||||
status: result.success && result.supported ? 'supported' : 'unsupported',
|
||||
testedAt: new Date().toISOString()
|
||||
};
|
||||
localStorage.setItem('mcp_verified_config', JSON.stringify(configToCache));
|
||||
|
||||
if (result.success && result.supported) {
|
||||
setModelSupportStatus('supported');
|
||||
|
||||
modal.success({
|
||||
title: '✅ Function Calling 支持检测',
|
||||
centered: true,
|
||||
width: isMobile ? '95%' : 700,
|
||||
content: (
|
||||
<div style={{ padding: '8px 0' }}>
|
||||
<div style={{ marginBottom: 16, padding: 12, background: 'var(--color-success-bg)', border: '1px solid var(--color-success-border)', borderRadius: 8 }}>
|
||||
<Typography.Text strong style={{ color: 'var(--color-success)', fontSize: 14 }}>
|
||||
✓ {result.message}
|
||||
</Typography.Text>
|
||||
</div>
|
||||
|
||||
<div style={{ display: 'grid', gridTemplateColumns: isMobile ? '1fr' : '1fr 1fr', gap: 12, marginBottom: 16 }}>
|
||||
<div style={{ padding: 12, background: 'var(--color-bg-layout)', borderRadius: 8 }}>
|
||||
<Text type="secondary" style={{ fontSize: 12 }}>API 提供商</Text>
|
||||
<div><Text strong style={{ fontSize: 16 }}>{result.provider}</Text></div>
|
||||
</div>
|
||||
<div style={{ padding: 12, background: 'var(--color-bg-layout)', borderRadius: 8 }}>
|
||||
<Text type="secondary" style={{ fontSize: 12 }}>响应时间</Text>
|
||||
<div><Text strong style={{ fontSize: 16 }}>{result.response_time_ms?.toFixed(0) || 0}ms</Text></div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div style={{ marginBottom: 12, padding: 12, background: 'var(--color-info-bg)', borderRadius: 8, border: '1px solid var(--color-info-border)' }}>
|
||||
<Text type="secondary" style={{ fontSize: 12, display: 'block', marginBottom: 4 }}>🔧 模型信息</Text>
|
||||
<Text code strong>{result.model}</Text>
|
||||
{result.details?.finish_reason && (
|
||||
<Tag color="green" style={{ marginLeft: 8 }}>finish_reason: {result.details.finish_reason}</Tag>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{result.details && (
|
||||
<div style={{ marginBottom: 12 }}>
|
||||
<Text type="secondary" style={{ fontSize: 12, display: 'block', marginBottom: 4 }}>📊 检测详情</Text>
|
||||
<div style={{ padding: 8, background: 'var(--color-bg-layout)', borderRadius: 4, fontSize: 12 }}>
|
||||
<div>✓ 工具调用数量: {result.details.tool_call_count || 0}</div>
|
||||
<div>✓ 测试工具: {result.details.test_tool || 'N/A'}</div>
|
||||
<div>✓ 响应类型: {result.details.response_type || 'N/A'}</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{result.tool_calls && result.tool_calls.length > 0 && (
|
||||
<div style={{ marginBottom: 12 }}>
|
||||
<Text type="secondary" style={{ fontSize: 12, display: 'block', marginBottom: 4 }}>🔨 工具调用示例</Text>
|
||||
<pre style={{ margin: 0, padding: 8, background: 'var(--color-bg-layout)', borderRadius: 4, fontSize: 11, overflow: 'auto', maxHeight: 150 }}>
|
||||
{JSON.stringify(result.tool_calls[0], null, 2)}
|
||||
</pre>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{result.suggestions && result.suggestions.length > 0 && (
|
||||
<div style={{ padding: 12, background: 'var(--color-success-bg)', border: '1px solid var(--color-success-border)', borderRadius: 8 }}>
|
||||
<Text strong style={{ fontSize: 13, display: 'block', marginBottom: 8 }}>💡 建议</Text>
|
||||
<ul style={{ margin: 0, paddingLeft: 20, fontSize: 12 }}>
|
||||
{result.suggestions.map((s: string, i: number) => (
|
||||
<li key={i} style={{ marginBottom: 4 }}>{s}</li>
|
||||
))}
|
||||
</ul>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
),
|
||||
});
|
||||
} else {
|
||||
setModelSupportStatus('unsupported');
|
||||
modal.warning({
|
||||
title: '❌ Function Calling 支持检测',
|
||||
centered: true,
|
||||
width: isMobile ? '95%' : 700,
|
||||
content: (
|
||||
<div style={{ padding: '8px 0' }}>
|
||||
<div style={{ marginBottom: 16 }}>
|
||||
<Alert
|
||||
message={result.message || '模型不支持 Function Calling'}
|
||||
type="warning"
|
||||
showIcon
|
||||
/>
|
||||
</div>
|
||||
|
||||
{result.error && (
|
||||
<div style={{
|
||||
padding: 16,
|
||||
background: 'var(--color-warning-bg)',
|
||||
border: '1px solid var(--color-warning-border)',
|
||||
borderRadius: 8,
|
||||
marginBottom: 16
|
||||
}}>
|
||||
<Text strong style={{ fontSize: 14, display: 'block', marginBottom: 8 }}>错误信息:</Text>
|
||||
<Text style={{ fontSize: 13, fontFamily: 'monospace' }}>
|
||||
{result.error}
|
||||
</Text>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{result.response_preview && (
|
||||
<div style={{ marginBottom: 12 }}>
|
||||
<Text type="secondary" style={{ fontSize: 12, display: 'block', marginBottom: 4 }}>📝 模型返回内容(前200字符)</Text>
|
||||
<pre style={{ margin: 0, padding: 8, background: 'var(--color-bg-layout)', borderRadius: 4, fontSize: 11, overflow: 'auto', maxHeight: 100, whiteSpace: 'pre-wrap' }}>
|
||||
{result.response_preview}
|
||||
</pre>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{result.suggestions && result.suggestions.length > 0 && (
|
||||
<div style={{
|
||||
padding: 16,
|
||||
background: 'var(--color-info-bg)',
|
||||
border: '1px solid var(--color-info-border)',
|
||||
borderRadius: 8
|
||||
}}>
|
||||
<Text strong style={{ fontSize: 14, display: 'block', marginBottom: 8 }}>💡 建议:</Text>
|
||||
<ul style={{ margin: 0, paddingLeft: 20, fontSize: 13 }}>
|
||||
{result.suggestions.map((s: string, i: number) => (
|
||||
<li key={i} style={{ marginBottom: 4 }}>{s}</li>
|
||||
))}
|
||||
</ul>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
),
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Check function calling failed:', error);
|
||||
message.error('检测失败,请稍后重试');
|
||||
setModelSupportStatus('unsupported');
|
||||
} finally {
|
||||
setCheckingFunctionCalling(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleSubmit = async (values: { config_json: string; enabled: boolean; category?: string }) => {
|
||||
setLoading(true);
|
||||
try {
|
||||
// 验证JSON格式
|
||||
try {
|
||||
JSON.parse(values.config_json);
|
||||
} catch (e) {
|
||||
} catch {
|
||||
message.error('配置JSON格式错误,请检查');
|
||||
setLoading(false);
|
||||
return;
|
||||
@@ -289,8 +563,9 @@ export default function MCPPluginsPage() {
|
||||
setModalVisible(false);
|
||||
form.resetFields();
|
||||
loadPlugins();
|
||||
} catch (error: any) {
|
||||
const errorMsg = error?.response?.data?.detail || '操作失败';
|
||||
} catch (error: unknown) {
|
||||
const err = error as { response?: { data?: { detail?: string } } };
|
||||
const errorMsg = err?.response?.data?.detail || '操作失败';
|
||||
message.error(errorMsg);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
@@ -407,38 +682,104 @@ export default function MCPPluginsPage() {
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
{/* 使用提示 */}
|
||||
<Alert
|
||||
message={
|
||||
<Space align="center">
|
||||
<InfoCircleOutlined style={{ fontSize: 16, color: 'var(--color-primary)' }} />
|
||||
<Text strong style={{ fontSize: isMobile ? 13 : 14, color: 'var(--color-text-primary)' }}>什么是 MCP 插件?</Text>
|
||||
</Space>
|
||||
}
|
||||
description={
|
||||
<div>
|
||||
<Text style={{ fontSize: isMobile ? 12 : 13, display: 'block', marginBottom: 8 }}>
|
||||
• <strong>MCP (Model Context Protocol)</strong> 是一个标准化的协议,允许 AI 调用外部工具获取数据。
|
||||
</Text>
|
||||
<Text style={{ fontSize: isMobile ? 12 : 13, display: 'block' }}>
|
||||
• 通过添加 MCP 插件,AI 可以访问搜索引擎、数据库、API 等外部服务,增强创作能力。
|
||||
</Text>
|
||||
<div style={{ marginTop: isMobile ? 16 : 24, display: 'flex', gap: 16, flexDirection: isMobile ? 'column' : 'row' }}>
|
||||
<Card
|
||||
variant="borderless"
|
||||
style={{
|
||||
flex: 1,
|
||||
borderRadius: 12,
|
||||
background: 'rgba(255, 255, 255, 0.9)',
|
||||
border: '1px solid rgba(255, 255, 255, 0.6)',
|
||||
backdropFilter: 'blur(10px)',
|
||||
boxShadow: '0 4px 12px rgba(0, 0, 0, 0.03)'
|
||||
}}
|
||||
bodyStyle={{ padding: 20 }}
|
||||
>
|
||||
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center' }}>
|
||||
<Space align="start">
|
||||
<div style={{
|
||||
width: 40, height: 40, borderRadius: '50%',
|
||||
background: modelSupportStatus === 'supported' ? 'var(--color-success-bg)' : modelSupportStatus === 'unsupported' ? 'var(--color-error-bg)' : 'var(--color-info-bg)',
|
||||
display: 'flex', alignItems: 'center', justifyContent: 'center',
|
||||
border: `1px solid ${modelSupportStatus === 'supported' ? 'var(--color-success-border)' : modelSupportStatus === 'unsupported' ? 'var(--color-error-border)' : 'var(--color-info-border)'}`
|
||||
}}>
|
||||
{modelSupportStatus === 'supported' ? (
|
||||
<CheckCircleOutlined style={{ fontSize: 20, color: 'var(--color-success)' }} />
|
||||
) : modelSupportStatus === 'unsupported' ? (
|
||||
<CloseCircleOutlined style={{ fontSize: 20, color: 'var(--color-error)' }} />
|
||||
) : (
|
||||
<QuestionCircleOutlined style={{ fontSize: 20, color: 'var(--color-info)' }} />
|
||||
)}
|
||||
</div>
|
||||
<div>
|
||||
<Text strong style={{ fontSize: 16, display: 'block', color: 'var(--color-text-primary)' }}>模型能力检查</Text>
|
||||
<Text type="secondary" style={{ fontSize: 13 }}>
|
||||
{modelSupportStatus === 'supported'
|
||||
? '当前模型支持 Function Calling,可正常使用 MCP 插件'
|
||||
: modelSupportStatus === 'unsupported'
|
||||
? '当前模型不支持 Function Calling,无法使用 MCP 插件'
|
||||
: '请先检测模型是否支持 Function Calling 能力'}
|
||||
</Text>
|
||||
</div>
|
||||
</Space>
|
||||
<Button
|
||||
type={modelSupportStatus === 'supported' ? 'default' : 'primary'}
|
||||
icon={<ApiOutlined />}
|
||||
onClick={handleCheckFunctionCalling}
|
||||
loading={checkingFunctionCalling}
|
||||
style={{ borderRadius: 8 }}
|
||||
>
|
||||
{modelSupportStatus === 'unknown' ? '开始检测' : '重新检测'}
|
||||
</Button>
|
||||
</div>
|
||||
}
|
||||
type="info"
|
||||
showIcon={false}
|
||||
style={{
|
||||
marginTop: isMobile ? 16 : 24,
|
||||
borderRadius: 12,
|
||||
background: 'rgba(230, 247, 255, 0.6)',
|
||||
border: '1px solid rgba(145, 213, 255, 0.6)',
|
||||
backdropFilter: 'blur(5px)'
|
||||
}}
|
||||
/>
|
||||
</Card>
|
||||
|
||||
<Card
|
||||
variant="borderless"
|
||||
style={{
|
||||
flex: 1,
|
||||
borderRadius: 12,
|
||||
background: 'rgba(230, 247, 255, 0.6)',
|
||||
border: '1px solid rgba(145, 213, 255, 0.6)',
|
||||
backdropFilter: 'blur(10px)',
|
||||
boxShadow: '0 4px 12px rgba(0, 0, 0, 0.03)'
|
||||
}}
|
||||
bodyStyle={{ padding: 20 }}
|
||||
>
|
||||
<Space align="start">
|
||||
<InfoCircleOutlined style={{ fontSize: 20, color: 'var(--color-primary)', marginTop: 4 }} />
|
||||
<div>
|
||||
<Text strong style={{ fontSize: 16, display: 'block', color: 'var(--color-text-primary)', marginBottom: 4 }}>什么是 MCP 插件?</Text>
|
||||
<Text style={{ fontSize: 13, display: 'block', color: 'var(--color-text-secondary)', lineHeight: 1.6 }}>
|
||||
MCP (Model Context Protocol) 协议允许 AI 调用外部工具获取数据。通过添加插件,AI 可以访问搜索引擎、数据库、API 等服务,大幅增强创作能力。
|
||||
</Text>
|
||||
</div>
|
||||
</Space>
|
||||
</Card>
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
{/* 主内容区 */}
|
||||
<div style={{ flex: 1 }}>
|
||||
{/* 模型能力未验证时的警告提示 */}
|
||||
{modelSupportStatus !== 'supported' && plugins.length > 0 && (
|
||||
<Alert
|
||||
message={
|
||||
modelSupportStatus === 'unsupported'
|
||||
? '当前模型不支持 Function Calling,所有插件操作已禁用'
|
||||
: '请先完成模型能力检查,才能操作插件'
|
||||
}
|
||||
type={modelSupportStatus === 'unsupported' ? 'error' : 'warning'}
|
||||
showIcon
|
||||
icon={modelSupportStatus === 'unsupported' ? <CloseCircleOutlined /> : <WarningOutlined />}
|
||||
style={{ marginBottom: 16, borderRadius: 8 }}
|
||||
action={
|
||||
<Button size="small" type="primary" onClick={handleCheckFunctionCalling} loading={checkingFunctionCalling}>
|
||||
{modelSupportStatus === 'unknown' ? '开始检测' : '重新检测'}
|
||||
</Button>
|
||||
}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* 插件列表 */}
|
||||
<Spin spinning={loading}>
|
||||
@@ -479,7 +820,7 @@ export default function MCPPluginsPage() {
|
||||
{plugin.display_name || plugin.plugin_name}
|
||||
</Text>
|
||||
{getStatusTag(plugin)}
|
||||
<Tag color={plugin.plugin_type === 'http' ? 'blue' : 'cyan'}>
|
||||
<Tag color={plugin.plugin_type === 'http' || plugin.plugin_type === 'streamable_http' || plugin.plugin_type === 'sse' ? 'blue' : 'cyan'}>
|
||||
{plugin.plugin_type?.toUpperCase() || 'UNKNOWN'}
|
||||
</Tag>
|
||||
{plugin.category && plugin.category !== 'general' && (
|
||||
@@ -500,7 +841,7 @@ export default function MCPPluginsPage() {
|
||||
)}
|
||||
|
||||
{/* 只显示有值的URL或命令,脱敏处理敏感信息 */}
|
||||
{plugin.plugin_type === 'http' && plugin.server_url && (
|
||||
{(plugin.plugin_type === 'http' || plugin.plugin_type === 'streamable_http' || plugin.plugin_type === 'sse') && plugin.server_url && (
|
||||
<div style={{ fontSize: isMobile ? '11px' : '12px' }}>
|
||||
<Text type="secondary" code>
|
||||
{(() => {
|
||||
@@ -551,9 +892,10 @@ export default function MCPPluginsPage() {
|
||||
|
||||
<Space size="small" wrap>
|
||||
<Switch
|
||||
title={plugin.enabled ? '禁用插件' : '启用插件'}
|
||||
title={modelSupportStatus !== 'supported' ? '请先完成模型能力检查' : (plugin.enabled ? '禁用插件' : '启用插件')}
|
||||
checked={plugin.enabled}
|
||||
onChange={(checked) => handleToggle(plugin, checked)}
|
||||
disabled={modelSupportStatus !== 'supported'}
|
||||
size={isMobile ? 'small' : 'default'}
|
||||
style={{
|
||||
flexShrink: 0,
|
||||
@@ -563,30 +905,33 @@ export default function MCPPluginsPage() {
|
||||
}}
|
||||
/>
|
||||
<Button
|
||||
title="测试连接"
|
||||
title={modelSupportStatus !== 'supported' ? '请先完成模型能力检查' : '测试连接'}
|
||||
icon={<ThunderboltOutlined />}
|
||||
onClick={() => handleTest(plugin.id)}
|
||||
loading={testingPluginId === plugin.id}
|
||||
disabled={modelSupportStatus !== 'supported'}
|
||||
size={isMobile ? 'small' : 'middle'}
|
||||
/>
|
||||
<Button
|
||||
title="查看工具"
|
||||
title={modelSupportStatus !== 'supported' ? '请先完成模型能力检查' : '查看工具'}
|
||||
icon={<ToolOutlined />}
|
||||
onClick={() => handleViewTools(plugin.id)}
|
||||
disabled={!plugin.enabled || plugin.status !== 'active'}
|
||||
disabled={modelSupportStatus !== 'supported' || !plugin.enabled || plugin.status !== 'active'}
|
||||
size={isMobile ? 'small' : 'middle'}
|
||||
/>
|
||||
<Button
|
||||
title="编辑"
|
||||
title={modelSupportStatus !== 'supported' ? '请先完成模型能力检查' : '编辑'}
|
||||
icon={<EditOutlined />}
|
||||
onClick={() => handleEdit(plugin)}
|
||||
disabled={modelSupportStatus !== 'supported'}
|
||||
size={isMobile ? 'small' : 'middle'}
|
||||
/>
|
||||
<Button
|
||||
title="删除"
|
||||
title={modelSupportStatus !== 'supported' ? '请先完成模型能力检查' : '删除'}
|
||||
danger
|
||||
icon={<DeleteOutlined />}
|
||||
onClick={() => handleDelete(plugin)}
|
||||
disabled={modelSupportStatus !== 'supported'}
|
||||
size={isMobile ? 'small' : 'middle'}
|
||||
/>
|
||||
</Space>
|
||||
@@ -627,7 +972,7 @@ export default function MCPPluginsPage() {
|
||||
{
|
||||
"mcpServers": {
|
||||
"exa": {
|
||||
"type": "http",
|
||||
"type": "streamable_http",
|
||||
"url": "https://mcp.exa.ai/mcp?exaApiKey=YOUR_API_KEY",
|
||||
"headers": {}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { useState, useEffect } from 'react';
|
||||
import { useNavigate } from 'react-router-dom';
|
||||
import { Card, Form, Input, Button, Select, Slider, InputNumber, message, Space, Typography, Spin, Modal, Alert, Grid, Tabs, List, Tag, Popconfirm, Empty, Row, Col } from 'antd';
|
||||
import { SettingOutlined, SaveOutlined, DeleteOutlined, ReloadOutlined, ArrowLeftOutlined, InfoCircleOutlined, CheckCircleOutlined, CloseCircleOutlined, ThunderboltOutlined, PlusOutlined, EditOutlined, CopyOutlined } from '@ant-design/icons';
|
||||
import { settingsApi } from '../services/api';
|
||||
import { SettingOutlined, SaveOutlined, DeleteOutlined, ReloadOutlined, ArrowLeftOutlined, InfoCircleOutlined, CheckCircleOutlined, CloseCircleOutlined, ThunderboltOutlined, PlusOutlined, EditOutlined, CopyOutlined, WarningOutlined } from '@ant-design/icons';
|
||||
import { settingsApi, mcpPluginApi } from '../services/api';
|
||||
import type { SettingsUpdate, APIKeyPreset, PresetCreateRequest, APIKeyPresetConfig } from '../types';
|
||||
|
||||
const { Title, Text } = Typography;
|
||||
@@ -95,10 +95,86 @@ export default function SettingsPage() {
|
||||
const handleSave = async (values: SettingsUpdate) => {
|
||||
setLoading(true);
|
||||
try {
|
||||
// 检查是否与 MCP 缓存的配置不一致
|
||||
const verifiedConfigStr = localStorage.getItem('mcp_verified_config');
|
||||
let configChanged = false;
|
||||
|
||||
if (verifiedConfigStr) {
|
||||
try {
|
||||
const verifiedConfig = JSON.parse(verifiedConfigStr);
|
||||
configChanged =
|
||||
verifiedConfig.provider !== values.api_provider ||
|
||||
verifiedConfig.baseUrl !== values.api_base_url ||
|
||||
verifiedConfig.model !== values.llm_model;
|
||||
} catch (e) {
|
||||
console.error('Failed to parse verified config:', e);
|
||||
}
|
||||
}
|
||||
|
||||
await settingsApi.saveSettings(values);
|
||||
message.success('设置已保存');
|
||||
setHasSettings(true);
|
||||
setIsDefaultSettings(false);
|
||||
|
||||
// 如果配置发生变化,需要处理 MCP 插件
|
||||
if (configChanged) {
|
||||
// 清除 MCP 验证缓存
|
||||
localStorage.removeItem('mcp_verified_config');
|
||||
|
||||
// 检查并禁用所有 MCP 插件
|
||||
try {
|
||||
const plugins = await mcpPluginApi.getPlugins();
|
||||
const activePlugins = plugins.filter(p => p.enabled);
|
||||
|
||||
if (activePlugins.length > 0) {
|
||||
// 禁用所有插件
|
||||
message.loading({ content: '正在禁用 MCP 插件...', key: 'disable_mcp' });
|
||||
await Promise.all(activePlugins.map(p => mcpPluginApi.togglePlugin(p.id, false)));
|
||||
message.success({ content: '已禁用所有 MCP 插件', key: 'disable_mcp' });
|
||||
|
||||
// 显示提示弹窗
|
||||
modal.warning({
|
||||
title: (
|
||||
<Space>
|
||||
<WarningOutlined style={{ color: '#faad14' }} />
|
||||
<span>API 配置已更改</span>
|
||||
</Space>
|
||||
),
|
||||
centered: true,
|
||||
content: (
|
||||
<div style={{ padding: '8px 0' }}>
|
||||
<Alert
|
||||
message="检测到您修改了 API 配置(提供商、地址或模型),为确保 MCP 插件正常工作,系统已自动禁用所有插件。"
|
||||
type="warning"
|
||||
showIcon
|
||||
style={{ marginBottom: 16 }}
|
||||
/>
|
||||
<div style={{
|
||||
padding: 12,
|
||||
background: 'var(--color-info-bg)',
|
||||
border: '1px solid var(--color-info-border)',
|
||||
borderRadius: 8
|
||||
}}>
|
||||
<Text strong style={{ display: 'block', marginBottom: 8 }}>请完成以下步骤:</Text>
|
||||
<ol style={{ margin: 0, paddingLeft: 20, fontSize: 13 }}>
|
||||
<li>前往 MCP 插件管理页面</li>
|
||||
<li>重新进行"模型能力检查"</li>
|
||||
<li>确认新模型支持 Function Calling 后再启用插件</li>
|
||||
</ol>
|
||||
</div>
|
||||
</div>
|
||||
),
|
||||
okText: '前往 MCP 页面',
|
||||
cancelText: '稍后处理',
|
||||
onOk: () => {
|
||||
navigate('/mcp-plugins');
|
||||
},
|
||||
});
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Failed to disable MCP plugins:', err);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
message.error('保存设置失败');
|
||||
} finally {
|
||||
@@ -348,10 +424,94 @@ export default function SettingsPage() {
|
||||
|
||||
const handlePresetActivate = async (presetId: string, presetName: string) => {
|
||||
try {
|
||||
// 获取预设配置用于比较
|
||||
const preset = presets.find(p => p.id === presetId);
|
||||
|
||||
await settingsApi.activatePreset(presetId);
|
||||
message.success(`已激活预设: ${presetName}`);
|
||||
loadPresets();
|
||||
loadSettings(); // 重新加载当前配置
|
||||
|
||||
// 检查是否与 MCP 缓存的配置不一致
|
||||
if (preset) {
|
||||
const verifiedConfigStr = localStorage.getItem('mcp_verified_config');
|
||||
let configChanged = false;
|
||||
|
||||
if (verifiedConfigStr) {
|
||||
try {
|
||||
const verifiedConfig = JSON.parse(verifiedConfigStr);
|
||||
configChanged =
|
||||
verifiedConfig.provider !== preset.config.api_provider ||
|
||||
verifiedConfig.baseUrl !== preset.config.api_base_url ||
|
||||
verifiedConfig.model !== preset.config.llm_model;
|
||||
} catch (e) {
|
||||
console.error('Failed to parse verified config:', e);
|
||||
configChanged = true; // 解析失败也视为配置变化
|
||||
}
|
||||
} else {
|
||||
// 没有缓存的配置,如果有启用的插件也需要处理
|
||||
configChanged = true;
|
||||
}
|
||||
|
||||
if (configChanged) {
|
||||
// 清除 MCP 验证缓存
|
||||
localStorage.removeItem('mcp_verified_config');
|
||||
|
||||
// 检查并禁用所有 MCP 插件
|
||||
try {
|
||||
const plugins = await mcpPluginApi.getPlugins();
|
||||
const activePlugins = plugins.filter(p => p.enabled);
|
||||
|
||||
if (activePlugins.length > 0) {
|
||||
// 禁用所有插件
|
||||
message.loading({ content: '正在禁用 MCP 插件...', key: 'disable_mcp' });
|
||||
await Promise.all(activePlugins.map(p => mcpPluginApi.togglePlugin(p.id, false)));
|
||||
message.success({ content: '已禁用所有 MCP 插件', key: 'disable_mcp' });
|
||||
|
||||
// 显示提示弹窗
|
||||
modal.warning({
|
||||
title: (
|
||||
<Space>
|
||||
<WarningOutlined style={{ color: '#faad14' }} />
|
||||
<span>API 配置已更改</span>
|
||||
</Space>
|
||||
),
|
||||
centered: true,
|
||||
content: (
|
||||
<div style={{ padding: '8px 0' }}>
|
||||
<Alert
|
||||
message={`切换到预设「${presetName}」后,API 配置发生了变化。为确保 MCP 插件正常工作,系统已自动禁用所有插件。`}
|
||||
type="warning"
|
||||
showIcon
|
||||
style={{ marginBottom: 16 }}
|
||||
/>
|
||||
<div style={{
|
||||
padding: 12,
|
||||
background: 'var(--color-info-bg)',
|
||||
border: '1px solid var(--color-info-border)',
|
||||
borderRadius: 8
|
||||
}}>
|
||||
<Text strong style={{ display: 'block', marginBottom: 8 }}>请完成以下步骤:</Text>
|
||||
<ol style={{ margin: 0, paddingLeft: 20, fontSize: 13 }}>
|
||||
<li>前往 MCP 插件管理页面</li>
|
||||
<li>重新进行"模型能力检查"</li>
|
||||
<li>确认新模型支持 Function Calling 后再启用插件</li>
|
||||
</ol>
|
||||
</div>
|
||||
</div>
|
||||
),
|
||||
okText: '前往 MCP 页面',
|
||||
cancelText: '稍后处理',
|
||||
onOk: () => {
|
||||
navigate('/mcp-plugins');
|
||||
},
|
||||
});
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Failed to disable MCP plugins:', err);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
message.error('激活失败');
|
||||
console.error(error);
|
||||
|
||||
@@ -1,9 +1,4 @@
|
||||
import axios from 'axios';
|
||||
|
||||
interface MCPPluginSimpleCreate {
|
||||
config_json: string;
|
||||
enabled: boolean;
|
||||
}
|
||||
import { message } from 'antd';
|
||||
import { ssePost } from '../utils/sseClient';
|
||||
import type { SSEClientOptions } from '../utils/sseClient';
|
||||
@@ -50,8 +45,14 @@ import type {
|
||||
PresetCreateRequest,
|
||||
PresetUpdateRequest,
|
||||
PresetListResponse,
|
||||
ChapterPlanItem,
|
||||
} from '../types';
|
||||
|
||||
interface MCPPluginSimpleCreate {
|
||||
config_json: string;
|
||||
enabled: boolean;
|
||||
}
|
||||
|
||||
const api = axios.create({
|
||||
baseURL: '/api',
|
||||
timeout: 120000,
|
||||
@@ -205,6 +206,36 @@ export const settingsApi = {
|
||||
suggestions?: string[];
|
||||
}>('/settings/test', params),
|
||||
|
||||
checkFunctionCalling: (params: { api_key: string; api_base_url: string; provider: string; llm_model: string }) =>
|
||||
api.post<unknown, {
|
||||
success: boolean;
|
||||
supported: boolean;
|
||||
message: string;
|
||||
response_time_ms?: number;
|
||||
provider?: string;
|
||||
model?: string;
|
||||
details?: {
|
||||
finish_reason?: string;
|
||||
has_tool_calls?: boolean;
|
||||
tool_call_count?: number;
|
||||
test_tool?: string;
|
||||
test_prompt?: string;
|
||||
response_type?: string;
|
||||
};
|
||||
tool_calls?: Array<{
|
||||
id?: string;
|
||||
type?: string;
|
||||
function?: {
|
||||
name: string;
|
||||
arguments: string;
|
||||
};
|
||||
}>;
|
||||
response_preview?: string;
|
||||
error?: string;
|
||||
error_type?: string;
|
||||
suggestions?: string[];
|
||||
}>('/settings/check-function-calling', params),
|
||||
|
||||
// API配置预设管理
|
||||
getPresets: () =>
|
||||
api.get<unknown, PresetListResponse>('/settings/presets'),
|
||||
@@ -410,7 +441,7 @@ export const outlineApi = {
|
||||
api.post<unknown, OutlineExpansionResponse>(`/outlines/${outlineId}/expand`, data),
|
||||
|
||||
// 根据已有规划创建章节(避免重复AI调用)
|
||||
createChaptersFromPlans: (outlineId: string, chapterPlans: any[]) =>
|
||||
createChaptersFromPlans: (outlineId: string, chapterPlans: ChapterPlanItem[]) =>
|
||||
api.post<unknown, {
|
||||
outline_id: string;
|
||||
outline_title: string;
|
||||
@@ -711,6 +742,25 @@ export const wizardStreamApi = {
|
||||
options
|
||||
),
|
||||
|
||||
generateCareerSystemStream: (
|
||||
data: {
|
||||
project_id: string;
|
||||
provider?: string;
|
||||
model?: string;
|
||||
},
|
||||
options?: SSEClientOptions
|
||||
) => ssePost<{
|
||||
project_id: string;
|
||||
main_careers_count: number;
|
||||
sub_careers_count: number;
|
||||
main_careers: string[];
|
||||
sub_careers: string[];
|
||||
}>(
|
||||
'/api/wizard-stream/career-system',
|
||||
data,
|
||||
options
|
||||
),
|
||||
|
||||
generateCompleteOutlineStream: (
|
||||
data: {
|
||||
project_id: string;
|
||||
|
||||
@@ -356,7 +356,7 @@ export function useChapterSync() {
|
||||
message.progress || 0
|
||||
);
|
||||
}
|
||||
} else if (message.type === 'content' && message.content) {
|
||||
} else if ((message.type === 'content' || message.type === 'chunk') && message.content) {
|
||||
fullContent += message.content;
|
||||
if (onProgress) {
|
||||
onProgress(fullContent);
|
||||
|
||||
@@ -667,7 +667,7 @@ export interface MCPPlugin {
|
||||
plugin_name: string;
|
||||
display_name: string;
|
||||
description?: string;
|
||||
plugin_type: 'http' | 'stdio';
|
||||
plugin_type: 'http' | 'stdio' | 'streamable_http' | 'sse';
|
||||
category: string;
|
||||
|
||||
// HTTP类型字段
|
||||
@@ -693,7 +693,7 @@ export interface MCPPluginCreate {
|
||||
plugin_name: string;
|
||||
display_name?: string;
|
||||
description?: string;
|
||||
server_type: 'http' | 'stdio';
|
||||
server_type: 'http' | 'stdio' | 'streamable_http' | 'sse';
|
||||
server_url?: string;
|
||||
command?: string;
|
||||
args?: string[];
|
||||
|
||||
Reference in New Issue
Block a user