feat: 重构MCP功能和AI服务提供者架构

This commit is contained in:
xiamuceer-j
2026-01-09 17:13:19 +08:00
parent f3c224261d
commit 77c5489ff8
49 changed files with 4763 additions and 4307 deletions
+118 -164
View File
@@ -45,7 +45,7 @@ from app.services.memory_service import memory_service
from app.services.chapter_regenerator import ChapterRegenerator
from app.logger import get_logger
from app.api.settings import get_user_ai_service
from app.utils.sse_response import create_sse_response
from app.utils.sse_response import SSEResponse, create_sse_response
router = APIRouter(prefix="/chapters", tags=["章节管理"])
logger = get_logger(__name__)
@@ -1172,7 +1172,6 @@ async def generate_chapter_content_stream(
"""
style_id = generate_request.style_id
target_word_count = generate_request.target_word_count or 3000
enable_mcp = generate_request.enable_mcp if hasattr(generate_request, 'enable_mcp') else True
custom_model = generate_request.model if hasattr(generate_request, 'model') else None
temp_narrative_perspective = generate_request.narrative_perspective if hasattr(generate_request, 'narrative_perspective') else None
# 预先验证章节存在性(使用临时会话)
@@ -1211,25 +1210,36 @@ async def generate_chapter_content_stream(
# 获取当前用户ID(在生成器外部就需要)
current_user_id = getattr(request.state, "user_id", "system")
# 初始化标准进度追踪器
from app.utils.sse_response import WizardProgressTracker
tracker = WizardProgressTracker("章节")
try:
yield await tracker.start()
# 创建新的数据库会话
async for db_session in get_db(request):
# === 加载阶段 ===
yield await tracker.loading("加载章节信息...", 0.2)
# 重新获取章节信息
chapter_result = await db_session.execute(
select(Chapter).where(Chapter.id == chapter_id)
)
current_chapter = chapter_result.scalar_one_or_none()
if not current_chapter:
yield f"data: {json.dumps({'type': 'error', 'error': '章节不存在'}, ensure_ascii=False)}\n\n"
yield await tracker.error("章节不存在", 404)
return
yield await tracker.loading("加载项目信息...", 0.4)
# 获取项目信息
project_result = await db_session.execute(
select(Project).where(Project.id == current_chapter.project_id)
)
project = project_result.scalar_one_or_none()
if not project:
yield f"data: {json.dumps({'type': 'error', 'error': '项目不存在'}, ensure_ascii=False)}\n\n"
yield await tracker.error("项目不存在", 404)
return
# 获取项目的大纲模式
@@ -1333,80 +1343,7 @@ async def generate_chapter_content_stream(
logger.info(f" - 相关记忆: {chapter_context.context_stats.get('memory_count', 0)}")
logger.info(f" - 总上下文长度: {chapter_context.context_stats.get('total_length', 0)} 字符")
# 发送开始事件
yield f"data: {json.dumps({'type': 'start', 'message': '开始AI创作...'}, ensure_ascii=False)}\n\n"
# 发送初始进度0%
yield f"data: {json.dumps({'type': 'progress', 'progress': 0, 'message': '准备生成...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
# 🔧 MCP工具增强:收集章节参考资料(优化版)
mcp_reference_materials = ""
if enable_mcp and current_user_id:
try:
# 1️⃣ 静默检查工具可用性
from app.services.mcp_tool_service import mcp_tool_service
available_tools = await mcp_tool_service.get_user_enabled_tools(
user_id=current_user_id,
db_session=db_session
)
# 2️⃣ 只在有工具时才显示消息和调用
if available_tools:
yield f"data: {json.dumps({'type': 'progress', 'message': '🔍 使用MCP工具收集参考资料...', 'progress': 28}, ensure_ascii=False)}\n\n"
# 构建资料收集提示词
planning_prompt = f"""你正在为小说《{project.title}》创作第{current_chapter.chapter_number}章《{current_chapter.title}》。
【章节大纲】
{outline.content if outline else current_chapter.summary or '暂无大纲'}
【小说信息】
- 题材:{project.genre or '未设定'}
- 主题:{project.theme or '未设定'}
- 时代背景:{project.world_time_period or '未设定'}
- 地理位置:{project.world_location or '未设定'}
【任务】
请使用可用工具搜索相关背景资料,帮助创作更真实、更有深度的章节内容。
你可以查询:
1. 该章节涉及的历史事件或时代背景
2. 地理环境和场景描写参考
3. 相关领域的专业知识(如武术、科技、魔法等)
4. 文化习俗和生活细节
请根据章节内容,有针对性地查询1-2个最关键的问题。"""
# 调用MCP增强的AI(非流式,限制1轮避免超时)
planning_result = await user_ai_service.generate_text_with_mcp(
prompt=planning_prompt,
user_id=current_user_id,
db_session=db_session,
enable_mcp=True,
max_tool_rounds=2, # ✅ 减少为1轮,避免超时
tool_choice="auto",
provider=None,
model=None
)
# 3️⃣ 提取参考资料并显示结果
if planning_result.get("tool_calls_made", 0) > 0:
tool_count = planning_result["tool_calls_made"]
yield f"data: {json.dumps({'type': 'progress', 'message': f'✅ MCP工具调用成功({tool_count}次)', 'progress': 32}, ensure_ascii=False)}\n\n"
mcp_reference_materials = planning_result.get("content", "")
logger.info(f"📚 MCP工具收集参考资料:{len(mcp_reference_materials)} 字符")
else:
yield f"data: {json.dumps({'type': 'progress', 'message': '️ MCP未使用工具,继续', 'progress': 32}, ensure_ascii=False)}\n\n"
else:
logger.debug(f"用户 {current_user_id} 未启用MCP工具,跳过MCP增强")
# 未启用MCP时也发送进度,保持连贯性
yield f"data: {json.dumps({'type': 'progress', 'message': '准备生成内容...', 'progress': 10}, ensure_ascii=False)}\n\n"
except Exception as e:
logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式: {str(e)}")
yield f"data: {json.dumps({'type': 'progress', 'message': '⚠️ MCP工具暂时不可用,使用基础模式', 'progress': 10}, ensure_ascii=False)}\n\n"
else:
# 如果未启用MCP,也发送基础进度
yield f"data: {json.dumps({'type': 'progress', 'message': '开始构建创作上下文...', 'progress': 10}, ensure_ascii=False)}\n\n"
yield await tracker.loading("上下文构建完成", 0.8)
# 🎭 确定使用的叙事人称(临时指定 > 项目默认 > 系统默认)
chapter_perspective = (
@@ -1496,26 +1433,17 @@ async def generate_chapter_content_stream(
characters_info=characters_info or '暂无角色信息'
)
# 添加 MCP 参考资料(如果有)
if mcp_reference_materials:
mcp_section = f"\n\n<mcp_reference>\n{mcp_reference_materials}\n</mcp_reference>"
base_prompt = base_prompt.replace("</task>", f"{mcp_section}\n</task>")
logger.info(f"📖 已整合MCP参考资料({len(mcp_reference_materials)}字符)")
# 应用写作风格
if style_content:
prompt = WritingStyleManager.apply_style_to_prompt(base_prompt, style_content)
else:
prompt = base_prompt
if mcp_reference_materials:
logger.info(f"📖 已整合MCP参考资料({len(mcp_reference_materials)}字符)到章节生成提示词")
# === 准备阶段 ===
yield await tracker.preparing("准备AI提示词...")
logger.info(f"开始AI流式创作章节 {chapter_id}")
# 发送开始生成的进度
yield f"data: {json.dumps({'type': 'progress', 'progress': 10, 'message': '开始AI创作...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
# 🎨 方案一:将写作风格注入到系统提示词(最高优先级)
system_prompt_with_style = None
if style_content:
@@ -1530,7 +1458,8 @@ async def generate_chapter_content_stream(
# 准备生成参数
generate_kwargs = {
"prompt": prompt,
"system_prompt": system_prompt_with_style # 🔑 关键:使用系统提示词传递风格
"system_prompt": system_prompt_with_style,
"tool_choice": "required"
}
if custom_model:
logger.info(f" 使用自定义模型: {custom_model}")
@@ -1538,47 +1467,38 @@ async def generate_chapter_content_stream(
# 注意:这里使用用户配置的AI服务,模型参数会覆盖默认模型
# 如果需要切换provider,需要在前端传递provider参数
# 流式生成内容
# === 生成阶段 ===
full_content = ""
chunk_count = 0
last_progress = 0
yield await tracker.generating(
current_chars=0,
estimated_total=target_word_count
)
async for chunk in user_ai_service.generate_text_stream(**generate_kwargs):
full_content += chunk
chunk_count += 1
# 发送内容块
yield f"data: {json.dumps({'type': 'content', 'content': chunk}, ensure_ascii=False)}\n\n"
yield await tracker.generating_chunk(chunk)
# 每5个chunk发送一次进度更新10-95%,更平滑)
# 每5个chunk发送一次进度更新
if chunk_count % 5 == 0:
current_word_count = len(full_content)
# 优化进度计算:使用更平滑的递增方式
# 基于chunk数量和字数的混合计算,避免大幅跳跃
chunk_progress = min(40, chunk_count // 5) # chunk贡献最多40%
word_progress = min(45, int((current_word_count / target_word_count) * 45)) # 字数贡献最多45%
estimated_progress = min(95, 10 + chunk_progress + word_progress)
# 只在进度变化时发送
if estimated_progress > last_progress:
progress_data = {
'type': 'progress',
'progress': estimated_progress,
'message': f'正在创作中... 已生成 {current_word_count}',
'word_count': current_word_count,
'status': 'processing'
}
yield f"data: {json.dumps(progress_data, ensure_ascii=False)}\n\n"
last_progress = estimated_progress
yield await tracker.generating(
current_chars=len(full_content),
estimated_total=target_word_count,
message=f'正在创作中... 已生成 {len(full_content)}'
)
# 每20个chunk发送心跳
if chunk_count % 20 == 0:
yield f"data: {json.dumps({'type': 'heartbeat'}, ensure_ascii=False)}\n\n"
yield await tracker.heartbeat()
await asyncio.sleep(0) # 让出控制权
# 发送保存进度
yield f"data: {json.dumps({'type': 'progress', 'progress': 97, 'message': '正在保存章节...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
# === 保存阶段 ===
yield await tracker.saving("正在保存章节...", 0.3)
# 更新章节内容到数据库
old_word_count = current_chapter.word_count or 0
@@ -1634,25 +1554,28 @@ async def generate_chapter_content_stream(
ai_service=user_ai_service
)
# 发送最终进度100%
yield f"data: {json.dumps({'type': 'progress', 'progress': 99, 'message': '创作完成!', 'word_count': new_word_count, 'status': 'success'}, ensure_ascii=False)}\n\n"
yield await tracker.saving("章节保存完成", 0.8)
# 发送完成事件(包含分析任务ID
completion_data = {
'type': 'done',
'message': '创作完成',
# === 完成阶段 ===
yield await tracker.complete("创作完成!")
# 发送结果数据
yield await tracker.result({
'word_count': new_word_count,
'analysis_task_id': task_id
}
yield f"data: {json.dumps(completion_data, ensure_ascii=False)}\n\n"
})
# 发送分析开始事件
analysis_started_data = {
'type': 'analysis_started',
'task_id': task_id,
'message': '章节分析已开始'
}
yield f"data: {json.dumps(analysis_started_data, ensure_ascii=False)}\n\n"
# 发送分析开始事件(使用自定义事件)
yield await SSEResponse.send_event(
event='analysis_started',
data={
'task_id': task_id,
'message': '章节分析已开始'
}
)
# 发送完成信号
yield await tracker.done()
break # 退出async for db_session循环
@@ -1675,7 +1598,7 @@ async def generate_chapter_content_stream(
logger.info("章节生成事务已回滚(异常)")
except Exception as rollback_error:
logger.error(f"回滚失败: {str(rollback_error)}")
yield f"data: {json.dumps({'type': 'error', 'error': str(e)}, ensure_ascii=False)}\n\n"
yield await tracker.error(str(e))
finally:
# 确保数据库会话被正确关闭
if db_session:
@@ -2813,7 +2736,8 @@ async def generate_single_chapter_for_batch(
# 准备生成参数
generate_kwargs = {
"prompt": prompt,
"system_prompt": system_prompt_with_style # 🔑 关键:使用系统提示词传递风格
"system_prompt": system_prompt_with_style,
"tool_choice": "required"
}
# 如果传入了自定义模型,使用指定的模型
if custom_model:
@@ -3029,11 +2953,16 @@ async def regenerate_chapter_stream(
db_session = None
db_committed = False
# 初始化标准进度追踪器
from app.utils.sse_response import WizardProgressTracker
tracker = WizardProgressTracker("章节重新生成")
try:
yield await tracker.start()
# 创建独立数据库会话
async for db_session in get_db(request):
# 发送开始事件
yield f"data: {json.dumps({'type': 'start', 'message': '开始重新生成章节...'}, ensure_ascii=False)}\n\n"
yield await tracker.loading("加载章节信息...", 0.5)
# 创建重新生成任务
regen_task = RegenerationTask(
@@ -3062,13 +2991,25 @@ async def regenerate_chapter_stream(
task_id = regen_task.id
logger.info(f"📝 创建重新生成任务: {task_id}")
yield f"data: {json.dumps({'type': 'task_created', 'task_id': task_id}, ensure_ascii=False)}\n\n"
yield await tracker.preparing("准备重新生成...")
yield await SSEResponse.send_event(
event='task_created',
data={'task_id': task_id}
)
# 初始化重新生成器
regenerator = ChapterRegenerator(user_ai_service)
# 流式生成新内容
# === 生成阶段 ===
full_content = ""
estimated_total = regenerate_request.target_word_count or len(chapter.content)
yield await tracker.generating(
current_chars=0,
estimated_total=estimated_total
)
async for event in regenerator.regenerate_with_feedback(
chapter=chapter,
analysis=analysis,
@@ -3083,19 +3024,35 @@ async def regenerate_chapter_stream(
# 内容块
chunk = event['content']
full_content += chunk
yield f"data: {json.dumps({'type': 'chunk', 'content': chunk}, ensure_ascii=False)}\n\n"
yield await tracker.generating_chunk(chunk)
# 定期更新进度
if len(full_content) % 500 == 0:
yield await tracker.generating(
current_chars=len(full_content),
estimated_total=estimated_total,
message=f'重新生成中... 已生成 {len(full_content)}'
)
elif event['type'] == 'progress':
# 进度更新
progress_data = {
'type': 'progress',
'progress': event.get('progress', 0),
'message': event.get('message', ''),
'word_count': event.get('word_count', 0)
}
yield f"data: {json.dumps(progress_data, ensure_ascii=False)}\n\n"
# 进度更新 - 映射到对应阶段
progress = event.get('progress', 0)
message = event.get('message', '')
if progress < 20:
yield await tracker.preparing(message)
elif progress < 85:
yield await tracker.generating(
current_chars=len(full_content),
estimated_total=estimated_total,
message=message
)
else:
yield await tracker.parsing(message)
await asyncio.sleep(0)
# === 保存阶段 ===
yield await tracker.saving("保存重新生成的内容...", 0.5)
# 更新任务状态
regen_task.status = 'completed'
regen_task.regenerated_content = full_content
@@ -3108,25 +3065,22 @@ async def regenerate_chapter_stream(
await db_session.commit()
db_committed = True
# 先发送结果数据
result_data = {
'type': 'result',
'data': {
'task_id': task_id,
'word_count': len(full_content),
'version_number': regen_task.version_number,
'auto_applied': regenerate_request.auto_apply,
'diff_stats': diff_stats
}
}
yield f"data: {json.dumps(result_data, ensure_ascii=False)}\n\n"
yield await tracker.saving("保存完成", 0.9)
# 再发送完成事件
completion_data = {
'type': 'done',
'message': '重新生成完成'
}
yield f"data: {json.dumps(completion_data, ensure_ascii=False)}\n\n"
# === 完成阶段 ===
yield await tracker.complete("重新生成完成!")
# 发送结果数据
yield await tracker.result({
'task_id': task_id,
'word_count': len(full_content),
'version_number': regen_task.version_number,
'auto_applied': regenerate_request.auto_apply,
'diff_stats': diff_stats
})
# 发送完成信号
yield await tracker.done()
logger.info(f"✅ 章节重新生成完成: {chapter_id}, 任务: {task_id}")
@@ -3151,7 +3105,7 @@ async def regenerate_chapter_stream(
except Exception as update_error:
logger.error(f"更新任务失败状态失败: {str(update_error)}")
yield f"data: {json.dumps({'type': 'error', 'error': str(e)}, ensure_ascii=False)}\n\n"
yield await tracker.error(str(e))
finally:
if db_session:
+28 -130
View File
@@ -775,148 +775,46 @@ async def generate_character_stream(
logger.info(f"🎯 开始为项目 {request.project_id} 生成角色(SSE流式)")
try:
# 🔧 MCP工具增强:静默检查并收集参考资料
# 直接使用 AIService 流式生成
ai_response = ""
chunk_count = 0
if user_id:
try:
from app.services.mcp_tool_service import mcp_tool_service
available_tools = await mcp_tool_service.get_user_enabled_tools(
user_id=user_id,
db_session=db
)
# 只在有工具时才调用
if available_tools:
logger.info(f"🔍 检测到可用MCP工具,尝试收集参考资料...")
result = await user_ai_service.generate_text_with_mcp(
prompt=prompt,
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=2,
tool_choice="auto",
provider=None,
model=None
)
if isinstance(result, dict):
ai_response = result.get('content', '')
finish_reason = result.get('finish_reason', '')
tool_calls_made = result.get('tool_calls_made', 0)
# 🔧 修复:检查工具调用是否真正成功
if tool_calls_made > 0:
if finish_reason == 'tool_error':
logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式")
# 工具调用失败,重新用基础模式生成
ai_response = ""
elif not ai_response.strip():
logger.warning(f"⚠️ MCP工具调用后返回空响应,降级为基础模式")
# 工具调用成功但返回空内容,重新生成
ai_response = ""
else:
logger.info(f"✅ MCP工具调用成功({tool_calls_made}次),内容长度: {len(ai_response)}")
# MCP成功且有内容,模拟流式输出(分块发送)
chunk_size = 50
for i in range(0, len(ai_response), chunk_size):
chunk = ai_response[i:i+chunk_size]
chunk_count += 1
yield await SSEResponse.send_chunk(chunk)
if chunk_count % 3 == 0:
yield await SSEResponse.send_progress(
f"AI生成角色中... ({i+len(chunk)}/{len(ai_response)}字符)",
10 + min(85 * (i+len(chunk)) // len(ai_response), 85)
)
# 跳过后续的流式生成
ai_response = result.get('content', '')
else:
ai_response = result
# 如果MCP调用失败或返回空,继续走流式生成
if not ai_response or not ai_response.strip():
logger.info(f"🔄 开始流式生成...")
ai_response = ""
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
chunk_count += 1
ai_response += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度
if chunk_count % 5 == 0:
yield await SSEResponse.send_progress(
f"AI生成角色中... ({len(ai_response)}字符)",
10 + min(chunk_count // 2, 85)
)
# 心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
else:
logger.debug(f"用户 {user_id} 未启用MCP工具,使用流式基础模式")
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
chunk_count += 1
ai_response += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度
if chunk_count % 5 == 0:
yield await SSEResponse.send_progress(
f"AI生成角色中... ({len(ai_response)}字符)",
10 + min(chunk_count // 2, 85)
)
# 心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
except Exception as mcp_error:
logger.warning(f"⚠️ MCP工具调用异常,降级为流式基础模式: {str(mcp_error)}")
ai_response = ""
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
chunk_count += 1
ai_response += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度
if chunk_count % 5 == 0:
yield await SSEResponse.send_progress(
f"AI生成角色中... ({len(ai_response)}字符)",
10 + min(chunk_count // 2, 85)
)
# 心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
else:
logger.debug(f"未登录用户,使用流式基础模式")
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
chunk_count += 1
ai_response += chunk
logger.info(f"🎯 开始生成角色(流式模式)...")
yield await SSEResponse.send_progress("🎯 开始生成角色...", 15)
async for chunk in user_ai_service.generate_text_stream(
prompt=prompt,
tool_choice="required",
):
# chunk 现在可能是 dict 或 str,提取 content 字段
if isinstance(chunk, dict):
content = chunk.get("content", "")
else:
content = chunk
if content:
ai_response += content
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
yield await SSEResponse.send_chunk(content)
# 定期更新进度
if chunk_count % 5 == 0:
# 定期更新进度(每收到约500字符更新一次,避免过于频繁)
current_len = len(ai_response)
if current_len >= chunk_count * 500:
chunk_count += 1
# 使用实际字符数量计算进度,上限85%(留15%给后续解析和保存)
# 估算最终字符数约为提示词的8倍,最少3000字符
estimated_total = max(3000, len(prompt) * 8)
progress = min(15 + int(current_len / estimated_total * 70), 85)
yield await SSEResponse.send_progress(
f"AI生成角色中... ({len(ai_response)}字符)",
10 + min(chunk_count // 2, 85)
f"AI生成角色中... ({current_len}字符)",
progress
)
# 心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
except Exception as ai_error:
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
yield await SSEResponse.send_error(f"AI服务调用失败:{str(ai_error)}")
+210 -73
View File
@@ -1,4 +1,7 @@
"""MCP插件管理API"""
"""MCP插件管理API
重构后使用统一的MCPClientFacade门面来管理所有MCP操作。
"""
from fastapi import APIRouter, HTTPException, Depends, Query, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
@@ -17,9 +20,8 @@ from app.schemas.mcp_plugin import (
)
import json
from app.user_manager import User
from app.mcp.registry import mcp_registry
from app.mcp import mcp_client, MCPPluginConfig, PluginStatus
from app.services.mcp_test_service import mcp_test_service
from app.services.mcp_tool_service import mcp_tool_service
from app.logger import get_logger
logger = get_logger(__name__)
@@ -34,6 +36,31 @@ def require_login(request: Request) -> User:
return request.state.user
async def _register_plugin_to_facade(plugin: MCPPlugin, user_id: str) -> bool:
"""
将插件注册到统一门面
Args:
plugin: 插件对象
user_id: 用户ID
Returns:
是否注册成功
"""
if plugin.plugin_type in ["http", "streamable_http", "sse"] and plugin.server_url:
return await mcp_client.register(MCPPluginConfig(
user_id=user_id,
plugin_name=plugin.plugin_name,
url=plugin.server_url,
plugin_type=plugin.plugin_type,
headers=plugin.headers,
timeout=plugin.config.get('timeout', 60.0) if plugin.config else 60.0
))
else:
logger.warning(f"暂不支持的插件类型: {plugin.plugin_type}")
return False
@router.get("", response_model=List[MCPPluginResponse])
async def list_plugins(
enabled_only: bool = Query(False, description="只返回启用的插件"),
@@ -99,9 +126,9 @@ async def create_plugin(
await db.commit()
await db.refresh(plugin)
# 如果启用,加载到注册表
# 如果启用,注册到统一门面
if plugin.enabled:
success = await mcp_registry.load_plugin(plugin)
success = await _register_plugin_to_facade(plugin, user.user_id)
if success:
plugin.status = "active"
else:
@@ -153,7 +180,7 @@ async def create_plugin_simple(
# 提取配置
server_type = server_config.get("type", "http")
if server_type not in ["http", "stdio"]:
if server_type not in ["http", "stdio", "streamable_http", "sse"]:
raise HTTPException(status_code=400, detail=f"不支持的服务器类型: {server_type}")
# 检查插件名是否已存在
@@ -175,12 +202,12 @@ async def create_plugin_simple(
"sort_order": 0
}
if server_type == "http":
if server_type in ["http", "streamable_http", "sse"]:
plugin_data["server_url"] = server_config.get("url")
plugin_data["headers"] = server_config.get("headers", {})
if not plugin_data["server_url"]:
raise HTTPException(status_code=400, detail="HTTP类型插件必须提供url字段")
raise HTTPException(status_code=400, detail=f"{server_type}类型插件必须提供url字段")
elif server_type == "stdio":
plugin_data["command"] = server_config.get("command")
@@ -194,9 +221,9 @@ async def create_plugin_simple(
# 更新现有插件
logger.info(f"插件 {plugin_name} 已存在,执行更新操作")
# 先卸载旧插件
if existing.enabled:
await mcp_registry.unload_plugin(user.user_id, existing.plugin_name)
# 保存旧状态
old_enabled = existing.enabled
old_plugin_name = existing.plugin_name
# 更新字段
for key, value in plugin_data.items():
@@ -206,17 +233,24 @@ async def create_plugin_simple(
await db.commit()
await db.refresh(plugin)
# 如果启用,重新加载
# 数据库完成后进行MCP操作
if old_enabled:
try:
await mcp_client.unregister(user.user_id, old_plugin_name)
except Exception as e:
logger.warning(f"注销旧插件出错: {e}")
if plugin.enabled:
success = await mcp_registry.load_plugin(plugin)
if success:
plugin.status = "active"
plugin.last_error = None
else:
try:
success = await _register_plugin_to_facade(plugin, user.user_id)
plugin.status = "active" if success else "error"
plugin.last_error = None if success else "加载失败"
await db.commit()
except Exception as e:
logger.error(f"注册插件失败: {e}")
plugin.status = "error"
plugin.last_error = "加载失败"
await db.commit()
await db.refresh(plugin)
plugin.last_error = str(e)
await db.commit()
logger.info(f"用户 {user.user_id} 更新插件: {plugin_name}")
else:
@@ -230,16 +264,18 @@ async def create_plugin_simple(
await db.commit()
await db.refresh(plugin)
# 如果启用,加载到注册表
# 数据库完成后进行MCP操作
if plugin.enabled:
success = await mcp_registry.load_plugin(plugin)
if success:
plugin.status = "active"
else:
try:
success = await _register_plugin_to_facade(plugin, user.user_id)
plugin.status = "active" if success else "error"
plugin.last_error = None if success else "加载失败"
await db.commit()
except Exception as e:
logger.error(f"注册插件失败: {e}")
plugin.status = "error"
plugin.last_error = "加载失败"
await db.commit()
await db.refresh(plugin)
plugin.last_error = str(e)
await db.commit()
logger.info(f"用户 {user.user_id} 通过简化配置创建插件: {plugin_name}")
@@ -306,9 +342,10 @@ async def update_plugin(
await db.commit()
await db.refresh(plugin)
# 如果插件已启用,重新加载
# 如果插件已启用,重新注册
if plugin.enabled:
await mcp_registry.reload_plugin(plugin)
await mcp_client.unregister(user.user_id, plugin.plugin_name)
await _register_plugin_to_facade(plugin, user.user_id)
logger.info(f"用户 {user.user_id} 更新插件: {plugin.plugin_name}")
return plugin
@@ -334,8 +371,8 @@ async def delete_plugin(
if not plugin:
raise HTTPException(status_code=404, detail="插件不存在")
# 从注册表卸载
await mcp_registry.unload_plugin(user.user_id, plugin.plugin_name)
# 从统一门面注销
await mcp_client.unregister(user.user_id, plugin.plugin_name)
# 删除数据库记录
await db.delete(plugin)
@@ -366,27 +403,57 @@ async def toggle_plugin(
if not plugin:
raise HTTPException(status_code=404, detail="插件不存在")
plugin.enabled = enabled
# 保存插件信息用于后续MCP操作
plugin_name = plugin.plugin_name
plugin_type = plugin.plugin_type
server_url = plugin.server_url
headers = plugin.headers
config = plugin.config
if enabled:
# 启用:加载到注册表
success = await mcp_registry.load_plugin(plugin)
if success:
plugin.status = "active"
plugin.last_error = None
else:
plugin.status = "error"
plugin.last_error = "加载失败"
else:
# 禁用:从注册表卸载
await mcp_registry.unload_plugin(user.user_id, plugin.plugin_name)
# 先更新数据库状态
plugin.enabled = enabled
if not enabled:
plugin.status = "inactive"
await db.commit()
await db.refresh(plugin)
# 数据库操作完成后,再进行MCP操作
if enabled:
# 启用:注册到统一门面
try:
if plugin_type in ["http", "streamable_http", "sse"] and server_url:
success = await mcp_client.register(MCPPluginConfig(
user_id=user.user_id,
plugin_name=plugin_name,
url=server_url,
plugin_type=plugin_type,
headers=headers,
timeout=config.get('timeout', 60.0) if config else 60.0
))
else:
success = False
# 更新状态
plugin.status = "active" if success else "error"
plugin.last_error = None if success else "加载失败"
await db.commit()
await db.refresh(plugin)
except Exception as e:
logger.error(f"注册插件失败: {plugin_name}, 错误: {e}")
plugin.status = "error"
plugin.last_error = str(e)
await db.commit()
await db.refresh(plugin)
else:
# 禁用:从统一门面注销(不影响数据库状态)
try:
await mcp_client.unregister(user.user_id, plugin_name)
except Exception as e:
logger.warning(f"注销插件时出错(可忽略): {plugin_name}, 错误: {e}")
action = "启用" if enabled else "禁用"
logger.info(f"用户 {user.user_id} {action}插件: {plugin.plugin_name}")
logger.info(f"用户 {user.user_id} {action}插件: {plugin_name}")
return plugin
@@ -399,7 +466,7 @@ async def test_plugin(
"""
测试插件连接并调用工具验证功能
使用新的MCPTestService进行测试
使用MCPTestService进行测试
"""
result = await db.execute(
@@ -421,7 +488,7 @@ async def test_plugin(
suggestions=["点击开关按钮启用插件"]
)
# 使用新的测试服务
# 使用测试服务
try:
test_result = await mcp_test_service.test_plugin_with_ai(plugin, user, db)
@@ -447,32 +514,77 @@ async def test_plugin(
raise HTTPException(status_code=500, detail=f"测试失败: {str(e)}")
async def _ensure_plugin_loaded(
async def _ensure_plugin_registered(
plugin: MCPPlugin,
user_id: str
) -> bool:
"""
确保插件已加载(共享逻辑)
确保插件已注册到统一门面
Args:
plugin: 插件对象
user_id: 用户ID
Returns:
是否加载成功
是否成功
Raises:
HTTPException: 加载失败
HTTPException: 注册失败
"""
if not mcp_registry.get_client(user_id, plugin.plugin_name):
logger.info(f"插件 {plugin.plugin_name} 未加载,自动加载中...")
success = await mcp_registry.load_plugin(plugin)
try:
# 使用ensure_registered方法,它会检查是否已注册
if plugin.plugin_type in ["http", "streamable_http", "sse"] and plugin.server_url:
return await mcp_client.ensure_registered(
user_id=user_id,
plugin_name=plugin.plugin_name,
url=plugin.server_url,
plugin_type=plugin.plugin_type,
headers=plugin.headers
)
return False
except ValueError as e:
logger.info(f"插件 {plugin.plugin_name} 未注册,自动注册中...")
success = await _register_plugin_to_facade(plugin, user_id)
if not success:
raise HTTPException(
status_code=500,
detail=f"插件加载失败: {plugin.plugin_name}"
detail=f"插件注册失败: {plugin.plugin_name}"
)
return True
return True
@router.get("/{plugin_id}/status")
async def get_plugin_status(
plugin_id: str,
user: User = Depends(require_login),
db: AsyncSession = Depends(get_db)
):
"""获取插件的实时状态(包括内存中的会话状态)"""
result = await db.execute(
select(MCPPlugin).where(
MCPPlugin.id == plugin_id,
MCPPlugin.user_id == user.user_id
)
)
plugin = result.scalar_one_or_none()
if not plugin:
raise HTTPException(status_code=404, detail="插件不存在")
session_stats = mcp_client.get_session_stats()
session_key = f"{user.user_id}:{plugin.plugin_name}"
session_info = next((s for s in session_stats.get("sessions", []) if s["key"] == session_key), None)
return {
"plugin_id": plugin_id,
"plugin_name": plugin.plugin_name,
"db_status": plugin.status,
"session_status": session_info["status"] if session_info else None,
"is_registered": session_info is not None,
"error_rate": session_info["error_rate"] if session_info else 0,
"in_sync": (plugin.status == session_info["status"]) if session_info else (plugin.status == "inactive"),
"timestamp": datetime.now().isoformat()
}
@router.get("/metrics")
@@ -495,7 +607,8 @@ async def get_metrics(
- avg_duration_ms: 平均耗时(毫秒)
- last_call_time: 最后调用时间
"""
metrics = mcp_tool_service.get_metrics(tool_name)
# 使用统一门面获取指标
metrics = mcp_client.get_metrics(tool_name)
return {
"metrics": metrics,
@@ -518,7 +631,8 @@ async def get_cache_stats(
- cache_ttl_minutes: 缓存TTL(分钟)
- entries: 各缓存条目详情
"""
stats = mcp_tool_service.get_cache_stats()
# 使用统一门面获取缓存统计
stats = mcp_client.get_cache_stats()
return {
"cache_stats": stats,
@@ -526,6 +640,27 @@ async def get_cache_stats(
}
@router.get("/sessions/stats")
async def get_session_stats(
user: User = Depends(require_login)
):
"""
获取MCP会话统计信息
Returns:
会话统计信息,包含:
- total_sessions: 会话总数
- sessions: 各会话详情
"""
# 使用统一门面获取会话统计
stats = mcp_client.get_session_stats()
return {
"session_stats": stats,
"timestamp": datetime.now().isoformat()
}
@router.post("/cache/clear")
async def clear_cache(
user_id: Optional[str] = Query(None, description="用户ID(可选)"),
@@ -551,7 +686,8 @@ async def clear_cache(
# 如果没有指定user_id,使用当前用户
target_user_id = user_id or user.user_id
mcp_tool_service.clear_cache(target_user_id, plugin_name)
# 使用统一门面清理缓存
mcp_client.clear_cache(target_user_id, plugin_name)
message = "已清理"
if plugin_name:
@@ -594,12 +730,13 @@ async def get_plugin_tools(
raise HTTPException(status_code=400, detail="插件未启用")
try:
# 确保插件已加载
await _ensure_plugin_loaded(plugin, user.user_id)
# 确保插件已注册
await _ensure_plugin_registered(plugin, user.user_id)
tools = await mcp_registry.get_plugin_tools(user.user_id, plugin.plugin_name)
# 使用统一门面获取工具列表
tools = await mcp_client.get_tools(user.user_id, plugin.plugin_name)
# 更新缓存
# 更新数据库中的工具缓存
plugin.tools = tools
await db.commit()
@@ -640,22 +777,22 @@ async def call_mcp_tool(
raise HTTPException(status_code=400, detail="插件未启用")
try:
# 确保插件已加载
await _ensure_plugin_loaded(plugin, user.user_id)
# 确保插件已注册
await _ensure_plugin_registered(plugin, user.user_id)
# 调用工具
result = await mcp_registry.call_tool(
user.user_id,
plugin.plugin_name,
data.tool_name,
data.arguments
# 使用统一门面调用工具
tool_result = await mcp_client.call_tool(
user_id=user.user_id,
plugin_name=plugin.plugin_name,
tool_name=data.tool_name,
arguments=data.arguments
)
return {
"success": True,
"plugin_name": plugin.plugin_name,
"tool_name": data.tool_name,
"result": result
"result": tool_result
}
except HTTPException:
raise
File diff suppressed because it is too large Load Diff
+253 -7
View File
@@ -22,7 +22,7 @@ from app.schemas.settings import (
from app.user_manager import User
from app.logger import get_logger
from app.config import settings as app_settings, PROJECT_ROOT
from app.services.ai_service import AIService, create_user_ai_service
from app.services.ai_service import AIService, create_user_ai_service, create_user_ai_service_with_mcp
logger = get_logger(__name__)
@@ -53,9 +53,14 @@ async def get_user_ai_service(
db: AsyncSession = Depends(get_db)
) -> AIService:
"""
依赖:获取当前用户的AI服务实例
从数据库读取用户设置并创建对应的AI服务
依赖:获取当前用户的AI服务实例(支持MCP工具自动加载)
从数据库读取用户设置并创建对应的AI服务。
自动传递 user_id 和 db_session,使得 AIService 能够加载用户配置的MCP工具。
根据用户的所有MCP插件状态决定是否启用MCP:如果有启用的插件则启用,否则禁用。
"""
from app.models.mcp_plugin import MCPPlugin
result = await db.execute(
select(Settings).where(Settings.user_id == user.user_id)
)
@@ -73,15 +78,34 @@ async def get_user_ai_service(
await db.refresh(settings)
logger.info(f"用户 {user.user_id} 首次使用AI服务,已从.env同步设置到数据库")
# 使用用户设置创建AI服务实例(包括系统提示词)
return create_user_ai_service(
# 查询用户的所有MCP插件状态
mcp_result = await db.execute(
select(MCPPlugin).where(MCPPlugin.user_id == user.user_id)
)
mcp_plugins = mcp_result.scalars().all()
# 检查是否有启用的MCP插件
enable_mcp = any(plugin.enabled for plugin in mcp_plugins) if mcp_plugins else False
if mcp_plugins:
enabled_count = sum(1 for p in mcp_plugins if p.enabled)
logger.info(f"用户 {user.user_id}{len(mcp_plugins)} 个MCP插件,{enabled_count} 个启用,{enable_mcp} 决定使用MCP")
else:
logger.debug(f"用户 {user.user_id} 没有配置MCP插件,禁用MCP")
# ✅ 使用支持MCP的工厂函数创建AI服务实例
# 传递 user_id 和 db_session,使得 AIService 能够自动加载用户配置的MCP工具
return create_user_ai_service_with_mcp(
api_provider=settings.api_provider,
api_key=settings.api_key,
api_base_url=settings.api_base_url or "",
model_name=settings.llm_model,
temperature=settings.temperature,
max_tokens=settings.max_tokens,
system_prompt=settings.system_prompt # 传递系统提示词
user_id=user.user_id, # ✅ 传递 user_id
db_session=db, # ✅ 传递 db_session
system_prompt=settings.system_prompt,
enable_mcp=enable_mcp, # 根据MCP插件状态动态决定
)
@@ -327,6 +351,227 @@ class ApiTestRequest(BaseModel):
llm_model: str
@router.post("/check-function-calling")
async def check_function_calling_support(data: ApiTestRequest):
"""
检查模型是否支持 Function Calling(工具调用)
基于业界最佳实践的测试方法:
1. 发送包含工具定义的请求
2. 检查响应的 finish_reason 是否为 "tool_calls"
3. 验证响应中是否包含有效的 tool_calls 数据
Args:
data: 包含 API 配置的请求数据
Returns:
检测结果包含支持状态、详细信息和建议
"""
api_key = data.api_key
api_base_url = data.api_base_url
provider = data.provider
llm_model = data.llm_model
try:
start_time = time.time()
# 定义一个简单的测试工具(天气查询)
test_tools = [{
"type": "function",
"function": {
"name": "get_weather",
"description": "获取指定城市的当前天气信息",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "城市名称,例如:北京、上海、深圳"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "温度单位"
}
},
"required": ["city"]
}
}
}]
# 测试提示:故意设计一个需要调用工具的问题
test_prompt = "请告诉我北京现在的天气情况如何?"
logger.info(f"🧪 开始检测 Function Calling 支持")
logger.info(f" - 提供商: {provider}")
logger.info(f" - 模型: {llm_model}")
logger.info(f" - 测试工具: get_weather")
# 创建临时 AI 服务实例进行测试
test_service = AIService(
api_provider=provider,
api_key=api_key,
api_base_url=api_base_url,
default_model=llm_model,
default_temperature=0.3, # 使用较低温度以获得更确定的行为
default_max_tokens=200
)
# 发送带工具的测试请求
response = await test_service.generate_text(
prompt=test_prompt,
provider=provider,
model=llm_model,
temperature=0.3,
max_tokens=200,
tools=test_tools,
tool_choice="auto", # 让模型自动决定是否使用工具
auto_mcp=False # 禁用 MCP 自动加载
)
end_time = time.time()
response_time = round((end_time - start_time) * 1000, 2)
# 分析响应以确定是否支持 Function Calling
supported = False
finish_reason = None
tool_calls = None
response_content = None
if isinstance(response, dict):
# 检查 finish_reasonOpenAI 标准)
finish_reason = response.get("finish_reason")
# 检查是否有 tool_calls
if "tool_calls" in response and response["tool_calls"]:
supported = True
tool_calls = response["tool_calls"]
logger.info(f"✅ 检测到工具调用: {len(tool_calls)}")
# 记录返回的内容(如果有)
if "content" in response:
response_content = response["content"]
elif isinstance(response, str):
# 如果只返回字符串,说明不支持工具调用
response_content = response
logger.info(f" - 响应时间: {response_time}ms")
logger.info(f" - finish_reason: {finish_reason}")
logger.info(f" - 支持状态: {'✅ 支持' if supported else '❌ 不支持'}")
# 构建详细的返回信息
result = {
"success": True,
"supported": supported,
"message": "✅ 模型支持 Function Calling" if supported else "❌ 模型不支持 Function Calling",
"response_time_ms": response_time,
"provider": provider,
"model": llm_model,
"details": {
"finish_reason": finish_reason,
"has_tool_calls": bool(tool_calls),
"tool_call_count": len(tool_calls) if tool_calls else 0,
"test_tool": "get_weather",
"test_prompt": test_prompt,
"response_type": "tool_calls" if supported else "text"
}
}
# 添加工具调用详情
if tool_calls:
result["tool_calls"] = tool_calls
result["suggestions"] = [
"✅ 该模型支持 Function Calling,可以正常使用 MCP 插件",
"建议:启用需要的 MCP 插件以扩展 AI 能力",
"提示:测试成功检测到工具调用,模型能够正确解析和使用外部工具"
]
else:
result["response_preview"] = response_content[:200] if response_content else None
result["suggestions"] = [
"❌ 该模型不支持 Function Calling,无法使用 MCP 插件功能",
"建议:更换支持工具调用的模型",
"推荐模型:GPT-4 系列、GPT-4-turbo、Claude 3 Opus/Sonnet、Gemini 1.5 Pro 等",
"说明:模型返回了文本回复而非工具调用,表明不支持该功能"
]
return result
except ValueError as e:
error_msg = str(e)
logger.error(f"❌ Function Calling 检测配置错误: {error_msg}")
return {
"success": False,
"supported": False,
"message": "配置错误",
"error": error_msg,
"error_type": "ConfigurationError",
"suggestions": [
"请检查 API Key 是否正确",
"请确认 API Base URL 格式是否正确",
"请验证所选提供商与配置是否匹配"
]
}
except TimeoutError as e:
error_msg = str(e)
logger.error(f"❌ Function Calling 检测超时: {error_msg}")
return {
"success": False,
"supported": None,
"message": "检测超时",
"error": error_msg,
"error_type": "TimeoutError",
"suggestions": [
"请检查网络连接是否正常",
"请确认 API 服务是否可访问",
"建议:稍后重试或使用其他网络环境"
]
}
except Exception as e:
error_msg = str(e)
error_type = type(e).__name__
logger.error(f"❌ Function Calling 检测失败: {error_msg}")
logger.error(f" - 错误类型: {error_type}")
# 智能分析错误原因
suggestions = []
if "tool" in error_msg.lower() or "function" in error_msg.lower():
suggestions = [
"该模型可能不支持 Function Calling 功能",
"API 返回了与工具调用相关的错误",
"建议:更换支持工具调用的模型或联系 API 提供商"
]
elif "unauthorized" in error_msg.lower() or "401" in error_msg:
suggestions = [
"API Key 认证失败",
"请检查 API Key 是否正确且有效",
"请确认 API Key 是否有足够的权限"
]
elif "not found" in error_msg.lower() or "404" in error_msg:
suggestions = [
"模型不存在或不可用",
"请检查模型名称是否正确",
"请确认该模型在当前 API 中是否可用"
]
else:
suggestions = [
"检测过程中遇到未知错误",
"建议:检查所有配置参数是否正确",
"提示:查看详细错误信息以获取更多线索"
]
return {
"success": False,
"supported": False,
"message": "Function Calling 检测失败",
"error": error_msg,
"error_type": error_type,
"suggestions": suggestions
}
@router.post("/test")
async def test_api_connection(data: ApiTestRequest):
"""
@@ -370,7 +615,8 @@ async def test_api_connection(data: ApiTestRequest):
provider=provider,
model=llm_model,
temperature=0.7,
max_tokens=8000
max_tokens=8000,
auto_mcp=False # 测试时不加载MCP工具
)
end_time = time.time()
File diff suppressed because it is too large Load Diff