feat: 重构MCP功能和AI服务提供者架构
This commit is contained in:
+118
-164
@@ -45,7 +45,7 @@ from app.services.memory_service import memory_service
|
||||
from app.services.chapter_regenerator import ChapterRegenerator
|
||||
from app.logger import get_logger
|
||||
from app.api.settings import get_user_ai_service
|
||||
from app.utils.sse_response import create_sse_response
|
||||
from app.utils.sse_response import SSEResponse, create_sse_response
|
||||
|
||||
router = APIRouter(prefix="/chapters", tags=["章节管理"])
|
||||
logger = get_logger(__name__)
|
||||
@@ -1172,7 +1172,6 @@ async def generate_chapter_content_stream(
|
||||
"""
|
||||
style_id = generate_request.style_id
|
||||
target_word_count = generate_request.target_word_count or 3000
|
||||
enable_mcp = generate_request.enable_mcp if hasattr(generate_request, 'enable_mcp') else True
|
||||
custom_model = generate_request.model if hasattr(generate_request, 'model') else None
|
||||
temp_narrative_perspective = generate_request.narrative_perspective if hasattr(generate_request, 'narrative_perspective') else None
|
||||
# 预先验证章节存在性(使用临时会话)
|
||||
@@ -1211,25 +1210,36 @@ async def generate_chapter_content_stream(
|
||||
# 获取当前用户ID(在生成器外部就需要)
|
||||
current_user_id = getattr(request.state, "user_id", "system")
|
||||
|
||||
# 初始化标准进度追踪器
|
||||
from app.utils.sse_response import WizardProgressTracker
|
||||
tracker = WizardProgressTracker("章节")
|
||||
|
||||
try:
|
||||
yield await tracker.start()
|
||||
|
||||
# 创建新的数据库会话
|
||||
async for db_session in get_db(request):
|
||||
# === 加载阶段 ===
|
||||
yield await tracker.loading("加载章节信息...", 0.2)
|
||||
|
||||
# 重新获取章节信息
|
||||
chapter_result = await db_session.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
current_chapter = chapter_result.scalar_one_or_none()
|
||||
if not current_chapter:
|
||||
yield f"data: {json.dumps({'type': 'error', 'error': '章节不存在'}, ensure_ascii=False)}\n\n"
|
||||
yield await tracker.error("章节不存在", 404)
|
||||
return
|
||||
|
||||
yield await tracker.loading("加载项目信息...", 0.4)
|
||||
|
||||
# 获取项目信息
|
||||
project_result = await db_session.execute(
|
||||
select(Project).where(Project.id == current_chapter.project_id)
|
||||
)
|
||||
project = project_result.scalar_one_or_none()
|
||||
if not project:
|
||||
yield f"data: {json.dumps({'type': 'error', 'error': '项目不存在'}, ensure_ascii=False)}\n\n"
|
||||
yield await tracker.error("项目不存在", 404)
|
||||
return
|
||||
|
||||
# 获取项目的大纲模式
|
||||
@@ -1333,80 +1343,7 @@ async def generate_chapter_content_stream(
|
||||
logger.info(f" - 相关记忆: {chapter_context.context_stats.get('memory_count', 0)} 条")
|
||||
logger.info(f" - 总上下文长度: {chapter_context.context_stats.get('total_length', 0)} 字符")
|
||||
|
||||
# 发送开始事件
|
||||
yield f"data: {json.dumps({'type': 'start', 'message': '开始AI创作...'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 发送初始进度0%
|
||||
yield f"data: {json.dumps({'type': 'progress', 'progress': 0, 'message': '准备生成...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 🔧 MCP工具增强:收集章节参考资料(优化版)
|
||||
mcp_reference_materials = ""
|
||||
if enable_mcp and current_user_id:
|
||||
try:
|
||||
# 1️⃣ 静默检查工具可用性
|
||||
from app.services.mcp_tool_service import mcp_tool_service
|
||||
available_tools = await mcp_tool_service.get_user_enabled_tools(
|
||||
user_id=current_user_id,
|
||||
db_session=db_session
|
||||
)
|
||||
|
||||
# 2️⃣ 只在有工具时才显示消息和调用
|
||||
if available_tools:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': '🔍 使用MCP工具收集参考资料...', 'progress': 28}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 构建资料收集提示词
|
||||
planning_prompt = f"""你正在为小说《{project.title}》创作第{current_chapter.chapter_number}章《{current_chapter.title}》。
|
||||
|
||||
【章节大纲】
|
||||
{outline.content if outline else current_chapter.summary or '暂无大纲'}
|
||||
|
||||
【小说信息】
|
||||
- 题材:{project.genre or '未设定'}
|
||||
- 主题:{project.theme or '未设定'}
|
||||
- 时代背景:{project.world_time_period or '未设定'}
|
||||
- 地理位置:{project.world_location or '未设定'}
|
||||
|
||||
【任务】
|
||||
请使用可用工具搜索相关背景资料,帮助创作更真实、更有深度的章节内容。
|
||||
你可以查询:
|
||||
1. 该章节涉及的历史事件或时代背景
|
||||
2. 地理环境和场景描写参考
|
||||
3. 相关领域的专业知识(如武术、科技、魔法等)
|
||||
4. 文化习俗和生活细节
|
||||
|
||||
请根据章节内容,有针对性地查询1-2个最关键的问题。"""
|
||||
|
||||
# 调用MCP增强的AI(非流式,限制1轮避免超时)
|
||||
planning_result = await user_ai_service.generate_text_with_mcp(
|
||||
prompt=planning_prompt,
|
||||
user_id=current_user_id,
|
||||
db_session=db_session,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=2, # ✅ 减少为1轮,避免超时
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
)
|
||||
|
||||
# 3️⃣ 提取参考资料并显示结果
|
||||
if planning_result.get("tool_calls_made", 0) > 0:
|
||||
tool_count = planning_result["tool_calls_made"]
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': f'✅ MCP工具调用成功({tool_count}次)', 'progress': 32}, ensure_ascii=False)}\n\n"
|
||||
mcp_reference_materials = planning_result.get("content", "")
|
||||
logger.info(f"📚 MCP工具收集参考资料:{len(mcp_reference_materials)} 字符")
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': 'ℹ️ MCP未使用工具,继续', 'progress': 32}, ensure_ascii=False)}\n\n"
|
||||
else:
|
||||
logger.debug(f"用户 {current_user_id} 未启用MCP工具,跳过MCP增强")
|
||||
# 未启用MCP时也发送进度,保持连贯性
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': '准备生成内容...', 'progress': 10}, ensure_ascii=False)}\n\n"
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式: {str(e)}")
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': '⚠️ MCP工具暂时不可用,使用基础模式', 'progress': 10}, ensure_ascii=False)}\n\n"
|
||||
else:
|
||||
# 如果未启用MCP,也发送基础进度
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': '开始构建创作上下文...', 'progress': 10}, ensure_ascii=False)}\n\n"
|
||||
yield await tracker.loading("上下文构建完成", 0.8)
|
||||
|
||||
# 🎭 确定使用的叙事人称(临时指定 > 项目默认 > 系统默认)
|
||||
chapter_perspective = (
|
||||
@@ -1496,26 +1433,17 @@ async def generate_chapter_content_stream(
|
||||
characters_info=characters_info or '暂无角色信息'
|
||||
)
|
||||
|
||||
# 添加 MCP 参考资料(如果有)
|
||||
if mcp_reference_materials:
|
||||
mcp_section = f"\n\n<mcp_reference>\n{mcp_reference_materials}\n</mcp_reference>"
|
||||
base_prompt = base_prompt.replace("</task>", f"{mcp_section}\n</task>")
|
||||
logger.info(f"📖 已整合MCP参考资料({len(mcp_reference_materials)}字符)")
|
||||
|
||||
# 应用写作风格
|
||||
if style_content:
|
||||
prompt = WritingStyleManager.apply_style_to_prompt(base_prompt, style_content)
|
||||
else:
|
||||
prompt = base_prompt
|
||||
|
||||
if mcp_reference_materials:
|
||||
logger.info(f"📖 已整合MCP参考资料({len(mcp_reference_materials)}字符)到章节生成提示词")
|
||||
# === 准备阶段 ===
|
||||
yield await tracker.preparing("准备AI提示词...")
|
||||
|
||||
logger.info(f"开始AI流式创作章节 {chapter_id}")
|
||||
|
||||
# 发送开始生成的进度
|
||||
yield f"data: {json.dumps({'type': 'progress', 'progress': 10, 'message': '开始AI创作...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 🎨 方案一:将写作风格注入到系统提示词(最高优先级)
|
||||
system_prompt_with_style = None
|
||||
if style_content:
|
||||
@@ -1530,7 +1458,8 @@ async def generate_chapter_content_stream(
|
||||
# 准备生成参数
|
||||
generate_kwargs = {
|
||||
"prompt": prompt,
|
||||
"system_prompt": system_prompt_with_style # 🔑 关键:使用系统提示词传递风格
|
||||
"system_prompt": system_prompt_with_style,
|
||||
"tool_choice": "required"
|
||||
}
|
||||
if custom_model:
|
||||
logger.info(f" 使用自定义模型: {custom_model}")
|
||||
@@ -1538,47 +1467,38 @@ async def generate_chapter_content_stream(
|
||||
# 注意:这里使用用户配置的AI服务,模型参数会覆盖默认模型
|
||||
# 如果需要切换provider,需要在前端传递provider参数
|
||||
|
||||
# 流式生成内容
|
||||
# === 生成阶段 ===
|
||||
full_content = ""
|
||||
chunk_count = 0
|
||||
last_progress = 0
|
||||
|
||||
yield await tracker.generating(
|
||||
current_chars=0,
|
||||
estimated_total=target_word_count
|
||||
)
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(**generate_kwargs):
|
||||
full_content += chunk
|
||||
chunk_count += 1
|
||||
|
||||
# 发送内容块
|
||||
yield f"data: {json.dumps({'type': 'content', 'content': chunk}, ensure_ascii=False)}\n\n"
|
||||
yield await tracker.generating_chunk(chunk)
|
||||
|
||||
# 每5个chunk发送一次进度更新(10-95%,更平滑)
|
||||
# 每5个chunk发送一次进度更新
|
||||
if chunk_count % 5 == 0:
|
||||
current_word_count = len(full_content)
|
||||
# 优化进度计算:使用更平滑的递增方式
|
||||
# 基于chunk数量和字数的混合计算,避免大幅跳跃
|
||||
chunk_progress = min(40, chunk_count // 5) # chunk贡献最多40%
|
||||
word_progress = min(45, int((current_word_count / target_word_count) * 45)) # 字数贡献最多45%
|
||||
estimated_progress = min(95, 10 + chunk_progress + word_progress)
|
||||
|
||||
# 只在进度变化时发送
|
||||
if estimated_progress > last_progress:
|
||||
progress_data = {
|
||||
'type': 'progress',
|
||||
'progress': estimated_progress,
|
||||
'message': f'正在创作中... 已生成 {current_word_count} 字',
|
||||
'word_count': current_word_count,
|
||||
'status': 'processing'
|
||||
}
|
||||
yield f"data: {json.dumps(progress_data, ensure_ascii=False)}\n\n"
|
||||
last_progress = estimated_progress
|
||||
yield await tracker.generating(
|
||||
current_chars=len(full_content),
|
||||
estimated_total=target_word_count,
|
||||
message=f'正在创作中... 已生成 {len(full_content)} 字'
|
||||
)
|
||||
|
||||
# 每20个chunk发送心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield f"data: {json.dumps({'type': 'heartbeat'}, ensure_ascii=False)}\n\n"
|
||||
yield await tracker.heartbeat()
|
||||
|
||||
await asyncio.sleep(0) # 让出控制权
|
||||
|
||||
# 发送保存进度
|
||||
yield f"data: {json.dumps({'type': 'progress', 'progress': 97, 'message': '正在保存章节...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
|
||||
# === 保存阶段 ===
|
||||
yield await tracker.saving("正在保存章节...", 0.3)
|
||||
|
||||
# 更新章节内容到数据库
|
||||
old_word_count = current_chapter.word_count or 0
|
||||
@@ -1634,25 +1554,28 @@ async def generate_chapter_content_stream(
|
||||
ai_service=user_ai_service
|
||||
)
|
||||
|
||||
# 发送最终进度100%
|
||||
yield f"data: {json.dumps({'type': 'progress', 'progress': 99, 'message': '创作完成!', 'word_count': new_word_count, 'status': 'success'}, ensure_ascii=False)}\n\n"
|
||||
yield await tracker.saving("章节保存完成", 0.8)
|
||||
|
||||
# 发送完成事件(包含分析任务ID)
|
||||
completion_data = {
|
||||
'type': 'done',
|
||||
'message': '创作完成',
|
||||
# === 完成阶段 ===
|
||||
yield await tracker.complete("创作完成!")
|
||||
|
||||
# 发送结果数据
|
||||
yield await tracker.result({
|
||||
'word_count': new_word_count,
|
||||
'analysis_task_id': task_id
|
||||
}
|
||||
yield f"data: {json.dumps(completion_data, ensure_ascii=False)}\n\n"
|
||||
})
|
||||
|
||||
# 发送分析开始事件
|
||||
analysis_started_data = {
|
||||
'type': 'analysis_started',
|
||||
'task_id': task_id,
|
||||
'message': '章节分析已开始'
|
||||
}
|
||||
yield f"data: {json.dumps(analysis_started_data, ensure_ascii=False)}\n\n"
|
||||
# 发送分析开始事件(使用自定义事件)
|
||||
yield await SSEResponse.send_event(
|
||||
event='analysis_started',
|
||||
data={
|
||||
'task_id': task_id,
|
||||
'message': '章节分析已开始'
|
||||
}
|
||||
)
|
||||
|
||||
# 发送完成信号
|
||||
yield await tracker.done()
|
||||
|
||||
break # 退出async for db_session循环
|
||||
|
||||
@@ -1675,7 +1598,7 @@ async def generate_chapter_content_stream(
|
||||
logger.info("章节生成事务已回滚(异常)")
|
||||
except Exception as rollback_error:
|
||||
logger.error(f"回滚失败: {str(rollback_error)}")
|
||||
yield f"data: {json.dumps({'type': 'error', 'error': str(e)}, ensure_ascii=False)}\n\n"
|
||||
yield await tracker.error(str(e))
|
||||
finally:
|
||||
# 确保数据库会话被正确关闭
|
||||
if db_session:
|
||||
@@ -2813,7 +2736,8 @@ async def generate_single_chapter_for_batch(
|
||||
# 准备生成参数
|
||||
generate_kwargs = {
|
||||
"prompt": prompt,
|
||||
"system_prompt": system_prompt_with_style # 🔑 关键:使用系统提示词传递风格
|
||||
"system_prompt": system_prompt_with_style,
|
||||
"tool_choice": "required"
|
||||
}
|
||||
# 如果传入了自定义模型,使用指定的模型
|
||||
if custom_model:
|
||||
@@ -3029,11 +2953,16 @@ async def regenerate_chapter_stream(
|
||||
db_session = None
|
||||
db_committed = False
|
||||
|
||||
# 初始化标准进度追踪器
|
||||
from app.utils.sse_response import WizardProgressTracker
|
||||
tracker = WizardProgressTracker("章节重新生成")
|
||||
|
||||
try:
|
||||
yield await tracker.start()
|
||||
|
||||
# 创建独立数据库会话
|
||||
async for db_session in get_db(request):
|
||||
# 发送开始事件
|
||||
yield f"data: {json.dumps({'type': 'start', 'message': '开始重新生成章节...'}, ensure_ascii=False)}\n\n"
|
||||
yield await tracker.loading("加载章节信息...", 0.5)
|
||||
|
||||
# 创建重新生成任务
|
||||
regen_task = RegenerationTask(
|
||||
@@ -3062,13 +2991,25 @@ async def regenerate_chapter_stream(
|
||||
task_id = regen_task.id
|
||||
logger.info(f"📝 创建重新生成任务: {task_id}")
|
||||
|
||||
yield f"data: {json.dumps({'type': 'task_created', 'task_id': task_id}, ensure_ascii=False)}\n\n"
|
||||
yield await tracker.preparing("准备重新生成...")
|
||||
|
||||
yield await SSEResponse.send_event(
|
||||
event='task_created',
|
||||
data={'task_id': task_id}
|
||||
)
|
||||
|
||||
# 初始化重新生成器
|
||||
regenerator = ChapterRegenerator(user_ai_service)
|
||||
|
||||
# 流式生成新内容
|
||||
# === 生成阶段 ===
|
||||
full_content = ""
|
||||
estimated_total = regenerate_request.target_word_count or len(chapter.content)
|
||||
|
||||
yield await tracker.generating(
|
||||
current_chars=0,
|
||||
estimated_total=estimated_total
|
||||
)
|
||||
|
||||
async for event in regenerator.regenerate_with_feedback(
|
||||
chapter=chapter,
|
||||
analysis=analysis,
|
||||
@@ -3083,19 +3024,35 @@ async def regenerate_chapter_stream(
|
||||
# 内容块
|
||||
chunk = event['content']
|
||||
full_content += chunk
|
||||
yield f"data: {json.dumps({'type': 'chunk', 'content': chunk}, ensure_ascii=False)}\n\n"
|
||||
yield await tracker.generating_chunk(chunk)
|
||||
|
||||
# 定期更新进度
|
||||
if len(full_content) % 500 == 0:
|
||||
yield await tracker.generating(
|
||||
current_chars=len(full_content),
|
||||
estimated_total=estimated_total,
|
||||
message=f'重新生成中... 已生成 {len(full_content)} 字'
|
||||
)
|
||||
elif event['type'] == 'progress':
|
||||
# 进度更新
|
||||
progress_data = {
|
||||
'type': 'progress',
|
||||
'progress': event.get('progress', 0),
|
||||
'message': event.get('message', ''),
|
||||
'word_count': event.get('word_count', 0)
|
||||
}
|
||||
yield f"data: {json.dumps(progress_data, ensure_ascii=False)}\n\n"
|
||||
# 进度更新 - 映射到对应阶段
|
||||
progress = event.get('progress', 0)
|
||||
message = event.get('message', '')
|
||||
if progress < 20:
|
||||
yield await tracker.preparing(message)
|
||||
elif progress < 85:
|
||||
yield await tracker.generating(
|
||||
current_chars=len(full_content),
|
||||
estimated_total=estimated_total,
|
||||
message=message
|
||||
)
|
||||
else:
|
||||
yield await tracker.parsing(message)
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# === 保存阶段 ===
|
||||
yield await tracker.saving("保存重新生成的内容...", 0.5)
|
||||
|
||||
# 更新任务状态
|
||||
regen_task.status = 'completed'
|
||||
regen_task.regenerated_content = full_content
|
||||
@@ -3108,25 +3065,22 @@ async def regenerate_chapter_stream(
|
||||
await db_session.commit()
|
||||
db_committed = True
|
||||
|
||||
# 先发送结果数据
|
||||
result_data = {
|
||||
'type': 'result',
|
||||
'data': {
|
||||
'task_id': task_id,
|
||||
'word_count': len(full_content),
|
||||
'version_number': regen_task.version_number,
|
||||
'auto_applied': regenerate_request.auto_apply,
|
||||
'diff_stats': diff_stats
|
||||
}
|
||||
}
|
||||
yield f"data: {json.dumps(result_data, ensure_ascii=False)}\n\n"
|
||||
yield await tracker.saving("保存完成", 0.9)
|
||||
|
||||
# 再发送完成事件
|
||||
completion_data = {
|
||||
'type': 'done',
|
||||
'message': '重新生成完成'
|
||||
}
|
||||
yield f"data: {json.dumps(completion_data, ensure_ascii=False)}\n\n"
|
||||
# === 完成阶段 ===
|
||||
yield await tracker.complete("重新生成完成!")
|
||||
|
||||
# 发送结果数据
|
||||
yield await tracker.result({
|
||||
'task_id': task_id,
|
||||
'word_count': len(full_content),
|
||||
'version_number': regen_task.version_number,
|
||||
'auto_applied': regenerate_request.auto_apply,
|
||||
'diff_stats': diff_stats
|
||||
})
|
||||
|
||||
# 发送完成信号
|
||||
yield await tracker.done()
|
||||
|
||||
logger.info(f"✅ 章节重新生成完成: {chapter_id}, 任务: {task_id}")
|
||||
|
||||
@@ -3151,7 +3105,7 @@ async def regenerate_chapter_stream(
|
||||
except Exception as update_error:
|
||||
logger.error(f"更新任务失败状态失败: {str(update_error)}")
|
||||
|
||||
yield f"data: {json.dumps({'type': 'error', 'error': str(e)}, ensure_ascii=False)}\n\n"
|
||||
yield await tracker.error(str(e))
|
||||
|
||||
finally:
|
||||
if db_session:
|
||||
|
||||
+28
-130
@@ -775,148 +775,46 @@ async def generate_character_stream(
|
||||
logger.info(f"🎯 开始为项目 {request.project_id} 生成角色(SSE流式)")
|
||||
|
||||
try:
|
||||
# 🔧 MCP工具增强:静默检查并收集参考资料
|
||||
# 直接使用 AIService 流式生成
|
||||
ai_response = ""
|
||||
chunk_count = 0
|
||||
|
||||
if user_id:
|
||||
try:
|
||||
from app.services.mcp_tool_service import mcp_tool_service
|
||||
available_tools = await mcp_tool_service.get_user_enabled_tools(
|
||||
user_id=user_id,
|
||||
db_session=db
|
||||
)
|
||||
|
||||
# 只在有工具时才调用
|
||||
if available_tools:
|
||||
logger.info(f"🔍 检测到可用MCP工具,尝试收集参考资料...")
|
||||
result = await user_ai_service.generate_text_with_mcp(
|
||||
prompt=prompt,
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=2,
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
)
|
||||
|
||||
if isinstance(result, dict):
|
||||
ai_response = result.get('content', '')
|
||||
finish_reason = result.get('finish_reason', '')
|
||||
tool_calls_made = result.get('tool_calls_made', 0)
|
||||
|
||||
# 🔧 修复:检查工具调用是否真正成功
|
||||
if tool_calls_made > 0:
|
||||
if finish_reason == 'tool_error':
|
||||
logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式")
|
||||
# 工具调用失败,重新用基础模式生成
|
||||
ai_response = ""
|
||||
elif not ai_response.strip():
|
||||
logger.warning(f"⚠️ MCP工具调用后返回空响应,降级为基础模式")
|
||||
# 工具调用成功但返回空内容,重新生成
|
||||
ai_response = ""
|
||||
else:
|
||||
logger.info(f"✅ MCP工具调用成功({tool_calls_made}次),内容长度: {len(ai_response)}")
|
||||
# MCP成功且有内容,模拟流式输出(分块发送)
|
||||
chunk_size = 50
|
||||
for i in range(0, len(ai_response), chunk_size):
|
||||
chunk = ai_response[i:i+chunk_size]
|
||||
chunk_count += 1
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
if chunk_count % 3 == 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成角色中... ({i+len(chunk)}/{len(ai_response)}字符)",
|
||||
10 + min(85 * (i+len(chunk)) // len(ai_response), 85)
|
||||
)
|
||||
|
||||
# 跳过后续的流式生成
|
||||
ai_response = result.get('content', '')
|
||||
else:
|
||||
ai_response = result
|
||||
|
||||
# 如果MCP调用失败或返回空,继续走流式生成
|
||||
if not ai_response or not ai_response.strip():
|
||||
logger.info(f"🔄 开始流式生成...")
|
||||
ai_response = ""
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
chunk_count += 1
|
||||
ai_response += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度
|
||||
if chunk_count % 5 == 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成角色中... ({len(ai_response)}字符)",
|
||||
10 + min(chunk_count // 2, 85)
|
||||
)
|
||||
|
||||
# 心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
else:
|
||||
logger.debug(f"用户 {user_id} 未启用MCP工具,使用流式基础模式")
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
chunk_count += 1
|
||||
ai_response += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度
|
||||
if chunk_count % 5 == 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成角色中... ({len(ai_response)}字符)",
|
||||
10 + min(chunk_count // 2, 85)
|
||||
)
|
||||
|
||||
# 心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
except Exception as mcp_error:
|
||||
logger.warning(f"⚠️ MCP工具调用异常,降级为流式基础模式: {str(mcp_error)}")
|
||||
ai_response = ""
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
chunk_count += 1
|
||||
ai_response += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度
|
||||
if chunk_count % 5 == 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成角色中... ({len(ai_response)}字符)",
|
||||
10 + min(chunk_count // 2, 85)
|
||||
)
|
||||
|
||||
# 心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
else:
|
||||
logger.debug(f"未登录用户,使用流式基础模式")
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
chunk_count += 1
|
||||
ai_response += chunk
|
||||
logger.info(f"🎯 开始生成角色(流式模式)...")
|
||||
yield await SSEResponse.send_progress("🎯 开始生成角色...", 15)
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
tool_choice="required",
|
||||
):
|
||||
# chunk 现在可能是 dict 或 str,提取 content 字段
|
||||
if isinstance(chunk, dict):
|
||||
content = chunk.get("content", "")
|
||||
else:
|
||||
content = chunk
|
||||
|
||||
if content:
|
||||
ai_response += content
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
yield await SSEResponse.send_chunk(content)
|
||||
|
||||
# 定期更新进度
|
||||
if chunk_count % 5 == 0:
|
||||
# 定期更新进度(每收到约500字符更新一次,避免过于频繁)
|
||||
current_len = len(ai_response)
|
||||
if current_len >= chunk_count * 500:
|
||||
chunk_count += 1
|
||||
# 使用实际字符数量计算进度,上限85%(留15%给后续解析和保存)
|
||||
# 估算最终字符数约为提示词的8倍,最少3000字符
|
||||
estimated_total = max(3000, len(prompt) * 8)
|
||||
progress = min(15 + int(current_len / estimated_total * 70), 85)
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成角色中... ({len(ai_response)}字符)",
|
||||
10 + min(chunk_count // 2, 85)
|
||||
f"AI生成角色中... ({current_len}字符)",
|
||||
progress
|
||||
)
|
||||
|
||||
# 心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
|
||||
except Exception as ai_error:
|
||||
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
|
||||
yield await SSEResponse.send_error(f"AI服务调用失败:{str(ai_error)}")
|
||||
|
||||
+210
-73
@@ -1,4 +1,7 @@
|
||||
"""MCP插件管理API"""
|
||||
"""MCP插件管理API
|
||||
|
||||
重构后使用统一的MCPClientFacade门面来管理所有MCP操作。
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
@@ -17,9 +20,8 @@ from app.schemas.mcp_plugin import (
|
||||
)
|
||||
import json
|
||||
from app.user_manager import User
|
||||
from app.mcp.registry import mcp_registry
|
||||
from app.mcp import mcp_client, MCPPluginConfig, PluginStatus
|
||||
from app.services.mcp_test_service import mcp_test_service
|
||||
from app.services.mcp_tool_service import mcp_tool_service
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -34,6 +36,31 @@ def require_login(request: Request) -> User:
|
||||
return request.state.user
|
||||
|
||||
|
||||
async def _register_plugin_to_facade(plugin: MCPPlugin, user_id: str) -> bool:
|
||||
"""
|
||||
将插件注册到统一门面
|
||||
|
||||
Args:
|
||||
plugin: 插件对象
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
是否注册成功
|
||||
"""
|
||||
if plugin.plugin_type in ["http", "streamable_http", "sse"] and plugin.server_url:
|
||||
return await mcp_client.register(MCPPluginConfig(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin.plugin_name,
|
||||
url=plugin.server_url,
|
||||
plugin_type=plugin.plugin_type,
|
||||
headers=plugin.headers,
|
||||
timeout=plugin.config.get('timeout', 60.0) if plugin.config else 60.0
|
||||
))
|
||||
else:
|
||||
logger.warning(f"暂不支持的插件类型: {plugin.plugin_type}")
|
||||
return False
|
||||
|
||||
|
||||
@router.get("", response_model=List[MCPPluginResponse])
|
||||
async def list_plugins(
|
||||
enabled_only: bool = Query(False, description="只返回启用的插件"),
|
||||
@@ -99,9 +126,9 @@ async def create_plugin(
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
|
||||
# 如果启用,加载到注册表
|
||||
# 如果启用,注册到统一门面
|
||||
if plugin.enabled:
|
||||
success = await mcp_registry.load_plugin(plugin)
|
||||
success = await _register_plugin_to_facade(plugin, user.user_id)
|
||||
if success:
|
||||
plugin.status = "active"
|
||||
else:
|
||||
@@ -153,7 +180,7 @@ async def create_plugin_simple(
|
||||
# 提取配置
|
||||
server_type = server_config.get("type", "http")
|
||||
|
||||
if server_type not in ["http", "stdio"]:
|
||||
if server_type not in ["http", "stdio", "streamable_http", "sse"]:
|
||||
raise HTTPException(status_code=400, detail=f"不支持的服务器类型: {server_type}")
|
||||
|
||||
# 检查插件名是否已存在
|
||||
@@ -175,12 +202,12 @@ async def create_plugin_simple(
|
||||
"sort_order": 0
|
||||
}
|
||||
|
||||
if server_type == "http":
|
||||
if server_type in ["http", "streamable_http", "sse"]:
|
||||
plugin_data["server_url"] = server_config.get("url")
|
||||
plugin_data["headers"] = server_config.get("headers", {})
|
||||
|
||||
if not plugin_data["server_url"]:
|
||||
raise HTTPException(status_code=400, detail="HTTP类型插件必须提供url字段")
|
||||
raise HTTPException(status_code=400, detail=f"{server_type}类型插件必须提供url字段")
|
||||
|
||||
elif server_type == "stdio":
|
||||
plugin_data["command"] = server_config.get("command")
|
||||
@@ -194,9 +221,9 @@ async def create_plugin_simple(
|
||||
# 更新现有插件
|
||||
logger.info(f"插件 {plugin_name} 已存在,执行更新操作")
|
||||
|
||||
# 先卸载旧插件
|
||||
if existing.enabled:
|
||||
await mcp_registry.unload_plugin(user.user_id, existing.plugin_name)
|
||||
# 保存旧状态
|
||||
old_enabled = existing.enabled
|
||||
old_plugin_name = existing.plugin_name
|
||||
|
||||
# 更新字段
|
||||
for key, value in plugin_data.items():
|
||||
@@ -206,17 +233,24 @@ async def create_plugin_simple(
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
|
||||
# 如果启用,重新加载
|
||||
# 数据库完成后进行MCP操作
|
||||
if old_enabled:
|
||||
try:
|
||||
await mcp_client.unregister(user.user_id, old_plugin_name)
|
||||
except Exception as e:
|
||||
logger.warning(f"注销旧插件出错: {e}")
|
||||
|
||||
if plugin.enabled:
|
||||
success = await mcp_registry.load_plugin(plugin)
|
||||
if success:
|
||||
plugin.status = "active"
|
||||
plugin.last_error = None
|
||||
else:
|
||||
try:
|
||||
success = await _register_plugin_to_facade(plugin, user.user_id)
|
||||
plugin.status = "active" if success else "error"
|
||||
plugin.last_error = None if success else "加载失败"
|
||||
await db.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"注册插件失败: {e}")
|
||||
plugin.status = "error"
|
||||
plugin.last_error = "加载失败"
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
plugin.last_error = str(e)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"用户 {user.user_id} 更新插件: {plugin_name}")
|
||||
else:
|
||||
@@ -230,16 +264,18 @@ async def create_plugin_simple(
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
|
||||
# 如果启用,加载到注册表
|
||||
# 数据库完成后进行MCP操作
|
||||
if plugin.enabled:
|
||||
success = await mcp_registry.load_plugin(plugin)
|
||||
if success:
|
||||
plugin.status = "active"
|
||||
else:
|
||||
try:
|
||||
success = await _register_plugin_to_facade(plugin, user.user_id)
|
||||
plugin.status = "active" if success else "error"
|
||||
plugin.last_error = None if success else "加载失败"
|
||||
await db.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"注册插件失败: {e}")
|
||||
plugin.status = "error"
|
||||
plugin.last_error = "加载失败"
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
plugin.last_error = str(e)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"用户 {user.user_id} 通过简化配置创建插件: {plugin_name}")
|
||||
|
||||
@@ -306,9 +342,10 @@ async def update_plugin(
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
|
||||
# 如果插件已启用,重新加载
|
||||
# 如果插件已启用,重新注册
|
||||
if plugin.enabled:
|
||||
await mcp_registry.reload_plugin(plugin)
|
||||
await mcp_client.unregister(user.user_id, plugin.plugin_name)
|
||||
await _register_plugin_to_facade(plugin, user.user_id)
|
||||
|
||||
logger.info(f"用户 {user.user_id} 更新插件: {plugin.plugin_name}")
|
||||
return plugin
|
||||
@@ -334,8 +371,8 @@ async def delete_plugin(
|
||||
if not plugin:
|
||||
raise HTTPException(status_code=404, detail="插件不存在")
|
||||
|
||||
# 从注册表卸载
|
||||
await mcp_registry.unload_plugin(user.user_id, plugin.plugin_name)
|
||||
# 从统一门面注销
|
||||
await mcp_client.unregister(user.user_id, plugin.plugin_name)
|
||||
|
||||
# 删除数据库记录
|
||||
await db.delete(plugin)
|
||||
@@ -366,27 +403,57 @@ async def toggle_plugin(
|
||||
if not plugin:
|
||||
raise HTTPException(status_code=404, detail="插件不存在")
|
||||
|
||||
plugin.enabled = enabled
|
||||
# 保存插件信息用于后续MCP操作
|
||||
plugin_name = plugin.plugin_name
|
||||
plugin_type = plugin.plugin_type
|
||||
server_url = plugin.server_url
|
||||
headers = plugin.headers
|
||||
config = plugin.config
|
||||
|
||||
if enabled:
|
||||
# 启用:加载到注册表
|
||||
success = await mcp_registry.load_plugin(plugin)
|
||||
if success:
|
||||
plugin.status = "active"
|
||||
plugin.last_error = None
|
||||
else:
|
||||
plugin.status = "error"
|
||||
plugin.last_error = "加载失败"
|
||||
else:
|
||||
# 禁用:从注册表卸载
|
||||
await mcp_registry.unload_plugin(user.user_id, plugin.plugin_name)
|
||||
# 先更新数据库状态
|
||||
plugin.enabled = enabled
|
||||
if not enabled:
|
||||
plugin.status = "inactive"
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
|
||||
# 数据库操作完成后,再进行MCP操作
|
||||
if enabled:
|
||||
# 启用:注册到统一门面
|
||||
try:
|
||||
if plugin_type in ["http", "streamable_http", "sse"] and server_url:
|
||||
success = await mcp_client.register(MCPPluginConfig(
|
||||
user_id=user.user_id,
|
||||
plugin_name=plugin_name,
|
||||
url=server_url,
|
||||
plugin_type=plugin_type,
|
||||
headers=headers,
|
||||
timeout=config.get('timeout', 60.0) if config else 60.0
|
||||
))
|
||||
else:
|
||||
success = False
|
||||
|
||||
# 更新状态
|
||||
plugin.status = "active" if success else "error"
|
||||
plugin.last_error = None if success else "加载失败"
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
except Exception as e:
|
||||
logger.error(f"注册插件失败: {plugin_name}, 错误: {e}")
|
||||
plugin.status = "error"
|
||||
plugin.last_error = str(e)
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
else:
|
||||
# 禁用:从统一门面注销(不影响数据库状态)
|
||||
try:
|
||||
await mcp_client.unregister(user.user_id, plugin_name)
|
||||
except Exception as e:
|
||||
logger.warning(f"注销插件时出错(可忽略): {plugin_name}, 错误: {e}")
|
||||
|
||||
action = "启用" if enabled else "禁用"
|
||||
logger.info(f"用户 {user.user_id} {action}插件: {plugin.plugin_name}")
|
||||
logger.info(f"用户 {user.user_id} {action}插件: {plugin_name}")
|
||||
return plugin
|
||||
|
||||
|
||||
@@ -399,7 +466,7 @@ async def test_plugin(
|
||||
"""
|
||||
测试插件连接并调用工具验证功能
|
||||
|
||||
使用新的MCPTestService进行测试
|
||||
使用MCPTestService进行测试
|
||||
"""
|
||||
|
||||
result = await db.execute(
|
||||
@@ -421,7 +488,7 @@ async def test_plugin(
|
||||
suggestions=["点击开关按钮启用插件"]
|
||||
)
|
||||
|
||||
# 使用新的测试服务
|
||||
# 使用测试服务
|
||||
try:
|
||||
test_result = await mcp_test_service.test_plugin_with_ai(plugin, user, db)
|
||||
|
||||
@@ -447,32 +514,77 @@ async def test_plugin(
|
||||
raise HTTPException(status_code=500, detail=f"测试失败: {str(e)}")
|
||||
|
||||
|
||||
async def _ensure_plugin_loaded(
|
||||
async def _ensure_plugin_registered(
|
||||
plugin: MCPPlugin,
|
||||
user_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
确保插件已加载(共享逻辑)
|
||||
确保插件已注册到统一门面
|
||||
|
||||
Args:
|
||||
plugin: 插件对象
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
是否加载成功
|
||||
是否成功
|
||||
|
||||
Raises:
|
||||
HTTPException: 加载失败
|
||||
HTTPException: 注册失败
|
||||
"""
|
||||
if not mcp_registry.get_client(user_id, plugin.plugin_name):
|
||||
logger.info(f"插件 {plugin.plugin_name} 未加载,自动加载中...")
|
||||
success = await mcp_registry.load_plugin(plugin)
|
||||
try:
|
||||
# 使用ensure_registered方法,它会检查是否已注册
|
||||
if plugin.plugin_type in ["http", "streamable_http", "sse"] and plugin.server_url:
|
||||
return await mcp_client.ensure_registered(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin.plugin_name,
|
||||
url=plugin.server_url,
|
||||
plugin_type=plugin.plugin_type,
|
||||
headers=plugin.headers
|
||||
)
|
||||
return False
|
||||
except ValueError as e:
|
||||
logger.info(f"插件 {plugin.plugin_name} 未注册,自动注册中...")
|
||||
success = await _register_plugin_to_facade(plugin, user_id)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"插件加载失败: {plugin.plugin_name}"
|
||||
detail=f"插件注册失败: {plugin.plugin_name}"
|
||||
)
|
||||
return True
|
||||
return True
|
||||
|
||||
|
||||
@router.get("/{plugin_id}/status")
|
||||
async def get_plugin_status(
|
||||
plugin_id: str,
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取插件的实时状态(包括内存中的会话状态)"""
|
||||
result = await db.execute(
|
||||
select(MCPPlugin).where(
|
||||
MCPPlugin.id == plugin_id,
|
||||
MCPPlugin.user_id == user.user_id
|
||||
)
|
||||
)
|
||||
plugin = result.scalar_one_or_none()
|
||||
|
||||
if not plugin:
|
||||
raise HTTPException(status_code=404, detail="插件不存在")
|
||||
|
||||
session_stats = mcp_client.get_session_stats()
|
||||
session_key = f"{user.user_id}:{plugin.plugin_name}"
|
||||
session_info = next((s for s in session_stats.get("sessions", []) if s["key"] == session_key), None)
|
||||
|
||||
return {
|
||||
"plugin_id": plugin_id,
|
||||
"plugin_name": plugin.plugin_name,
|
||||
"db_status": plugin.status,
|
||||
"session_status": session_info["status"] if session_info else None,
|
||||
"is_registered": session_info is not None,
|
||||
"error_rate": session_info["error_rate"] if session_info else 0,
|
||||
"in_sync": (plugin.status == session_info["status"]) if session_info else (plugin.status == "inactive"),
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.get("/metrics")
|
||||
@@ -495,7 +607,8 @@ async def get_metrics(
|
||||
- avg_duration_ms: 平均耗时(毫秒)
|
||||
- last_call_time: 最后调用时间
|
||||
"""
|
||||
metrics = mcp_tool_service.get_metrics(tool_name)
|
||||
# 使用统一门面获取指标
|
||||
metrics = mcp_client.get_metrics(tool_name)
|
||||
|
||||
return {
|
||||
"metrics": metrics,
|
||||
@@ -518,7 +631,8 @@ async def get_cache_stats(
|
||||
- cache_ttl_minutes: 缓存TTL(分钟)
|
||||
- entries: 各缓存条目详情
|
||||
"""
|
||||
stats = mcp_tool_service.get_cache_stats()
|
||||
# 使用统一门面获取缓存统计
|
||||
stats = mcp_client.get_cache_stats()
|
||||
|
||||
return {
|
||||
"cache_stats": stats,
|
||||
@@ -526,6 +640,27 @@ async def get_cache_stats(
|
||||
}
|
||||
|
||||
|
||||
@router.get("/sessions/stats")
|
||||
async def get_session_stats(
|
||||
user: User = Depends(require_login)
|
||||
):
|
||||
"""
|
||||
获取MCP会话统计信息
|
||||
|
||||
Returns:
|
||||
会话统计信息,包含:
|
||||
- total_sessions: 会话总数
|
||||
- sessions: 各会话详情
|
||||
"""
|
||||
# 使用统一门面获取会话统计
|
||||
stats = mcp_client.get_session_stats()
|
||||
|
||||
return {
|
||||
"session_stats": stats,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.post("/cache/clear")
|
||||
async def clear_cache(
|
||||
user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
@@ -551,7 +686,8 @@ async def clear_cache(
|
||||
# 如果没有指定user_id,使用当前用户
|
||||
target_user_id = user_id or user.user_id
|
||||
|
||||
mcp_tool_service.clear_cache(target_user_id, plugin_name)
|
||||
# 使用统一门面清理缓存
|
||||
mcp_client.clear_cache(target_user_id, plugin_name)
|
||||
|
||||
message = "已清理"
|
||||
if plugin_name:
|
||||
@@ -594,12 +730,13 @@ async def get_plugin_tools(
|
||||
raise HTTPException(status_code=400, detail="插件未启用")
|
||||
|
||||
try:
|
||||
# 确保插件已加载
|
||||
await _ensure_plugin_loaded(plugin, user.user_id)
|
||||
# 确保插件已注册
|
||||
await _ensure_plugin_registered(plugin, user.user_id)
|
||||
|
||||
tools = await mcp_registry.get_plugin_tools(user.user_id, plugin.plugin_name)
|
||||
# 使用统一门面获取工具列表
|
||||
tools = await mcp_client.get_tools(user.user_id, plugin.plugin_name)
|
||||
|
||||
# 更新缓存
|
||||
# 更新数据库中的工具缓存
|
||||
plugin.tools = tools
|
||||
await db.commit()
|
||||
|
||||
@@ -640,22 +777,22 @@ async def call_mcp_tool(
|
||||
raise HTTPException(status_code=400, detail="插件未启用")
|
||||
|
||||
try:
|
||||
# 确保插件已加载
|
||||
await _ensure_plugin_loaded(plugin, user.user_id)
|
||||
# 确保插件已注册
|
||||
await _ensure_plugin_registered(plugin, user.user_id)
|
||||
|
||||
# 调用工具
|
||||
result = await mcp_registry.call_tool(
|
||||
user.user_id,
|
||||
plugin.plugin_name,
|
||||
data.tool_name,
|
||||
data.arguments
|
||||
# 使用统一门面调用工具
|
||||
tool_result = await mcp_client.call_tool(
|
||||
user_id=user.user_id,
|
||||
plugin_name=plugin.plugin_name,
|
||||
tool_name=data.tool_name,
|
||||
arguments=data.arguments
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"plugin_name": plugin.plugin_name,
|
||||
"tool_name": data.tool_name,
|
||||
"result": result
|
||||
"result": tool_result
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
|
||||
+235
-552
File diff suppressed because it is too large
Load Diff
+253
-7
@@ -22,7 +22,7 @@ from app.schemas.settings import (
|
||||
from app.user_manager import User
|
||||
from app.logger import get_logger
|
||||
from app.config import settings as app_settings, PROJECT_ROOT
|
||||
from app.services.ai_service import AIService, create_user_ai_service
|
||||
from app.services.ai_service import AIService, create_user_ai_service, create_user_ai_service_with_mcp
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -53,9 +53,14 @@ async def get_user_ai_service(
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> AIService:
|
||||
"""
|
||||
依赖:获取当前用户的AI服务实例
|
||||
从数据库读取用户设置并创建对应的AI服务
|
||||
依赖:获取当前用户的AI服务实例(支持MCP工具自动加载)
|
||||
|
||||
从数据库读取用户设置并创建对应的AI服务。
|
||||
自动传递 user_id 和 db_session,使得 AIService 能够加载用户配置的MCP工具。
|
||||
根据用户的所有MCP插件状态决定是否启用MCP:如果有启用的插件则启用,否则禁用。
|
||||
"""
|
||||
from app.models.mcp_plugin import MCPPlugin
|
||||
|
||||
result = await db.execute(
|
||||
select(Settings).where(Settings.user_id == user.user_id)
|
||||
)
|
||||
@@ -73,15 +78,34 @@ async def get_user_ai_service(
|
||||
await db.refresh(settings)
|
||||
logger.info(f"用户 {user.user_id} 首次使用AI服务,已从.env同步设置到数据库")
|
||||
|
||||
# 使用用户设置创建AI服务实例(包括系统提示词)
|
||||
return create_user_ai_service(
|
||||
# 查询用户的所有MCP插件状态
|
||||
mcp_result = await db.execute(
|
||||
select(MCPPlugin).where(MCPPlugin.user_id == user.user_id)
|
||||
)
|
||||
mcp_plugins = mcp_result.scalars().all()
|
||||
|
||||
# 检查是否有启用的MCP插件
|
||||
enable_mcp = any(plugin.enabled for plugin in mcp_plugins) if mcp_plugins else False
|
||||
|
||||
if mcp_plugins:
|
||||
enabled_count = sum(1 for p in mcp_plugins if p.enabled)
|
||||
logger.info(f"用户 {user.user_id} 有 {len(mcp_plugins)} 个MCP插件,{enabled_count} 个启用,{enable_mcp} 决定使用MCP")
|
||||
else:
|
||||
logger.debug(f"用户 {user.user_id} 没有配置MCP插件,禁用MCP")
|
||||
|
||||
# ✅ 使用支持MCP的工厂函数创建AI服务实例
|
||||
# 传递 user_id 和 db_session,使得 AIService 能够自动加载用户配置的MCP工具
|
||||
return create_user_ai_service_with_mcp(
|
||||
api_provider=settings.api_provider,
|
||||
api_key=settings.api_key,
|
||||
api_base_url=settings.api_base_url or "",
|
||||
model_name=settings.llm_model,
|
||||
temperature=settings.temperature,
|
||||
max_tokens=settings.max_tokens,
|
||||
system_prompt=settings.system_prompt # 传递系统提示词
|
||||
user_id=user.user_id, # ✅ 传递 user_id
|
||||
db_session=db, # ✅ 传递 db_session
|
||||
system_prompt=settings.system_prompt,
|
||||
enable_mcp=enable_mcp, # 根据MCP插件状态动态决定
|
||||
)
|
||||
|
||||
|
||||
@@ -327,6 +351,227 @@ class ApiTestRequest(BaseModel):
|
||||
llm_model: str
|
||||
|
||||
|
||||
@router.post("/check-function-calling")
|
||||
async def check_function_calling_support(data: ApiTestRequest):
|
||||
"""
|
||||
检查模型是否支持 Function Calling(工具调用)
|
||||
|
||||
基于业界最佳实践的测试方法:
|
||||
1. 发送包含工具定义的请求
|
||||
2. 检查响应的 finish_reason 是否为 "tool_calls"
|
||||
3. 验证响应中是否包含有效的 tool_calls 数据
|
||||
|
||||
Args:
|
||||
data: 包含 API 配置的请求数据
|
||||
|
||||
Returns:
|
||||
检测结果包含支持状态、详细信息和建议
|
||||
"""
|
||||
api_key = data.api_key
|
||||
api_base_url = data.api_base_url
|
||||
provider = data.provider
|
||||
llm_model = data.llm_model
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# 定义一个简单的测试工具(天气查询)
|
||||
test_tools = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "获取指定城市的当前天气信息",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "城市名称,例如:北京、上海、深圳"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "温度单位"
|
||||
}
|
||||
},
|
||||
"required": ["city"]
|
||||
}
|
||||
}
|
||||
}]
|
||||
|
||||
# 测试提示:故意设计一个需要调用工具的问题
|
||||
test_prompt = "请告诉我北京现在的天气情况如何?"
|
||||
|
||||
logger.info(f"🧪 开始检测 Function Calling 支持")
|
||||
logger.info(f" - 提供商: {provider}")
|
||||
logger.info(f" - 模型: {llm_model}")
|
||||
logger.info(f" - 测试工具: get_weather")
|
||||
|
||||
# 创建临时 AI 服务实例进行测试
|
||||
test_service = AIService(
|
||||
api_provider=provider,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
default_model=llm_model,
|
||||
default_temperature=0.3, # 使用较低温度以获得更确定的行为
|
||||
default_max_tokens=200
|
||||
)
|
||||
|
||||
# 发送带工具的测试请求
|
||||
response = await test_service.generate_text(
|
||||
prompt=test_prompt,
|
||||
provider=provider,
|
||||
model=llm_model,
|
||||
temperature=0.3,
|
||||
max_tokens=200,
|
||||
tools=test_tools,
|
||||
tool_choice="auto", # 让模型自动决定是否使用工具
|
||||
auto_mcp=False # 禁用 MCP 自动加载
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
response_time = round((end_time - start_time) * 1000, 2)
|
||||
|
||||
# 分析响应以确定是否支持 Function Calling
|
||||
supported = False
|
||||
finish_reason = None
|
||||
tool_calls = None
|
||||
response_content = None
|
||||
|
||||
if isinstance(response, dict):
|
||||
# 检查 finish_reason(OpenAI 标准)
|
||||
finish_reason = response.get("finish_reason")
|
||||
|
||||
# 检查是否有 tool_calls
|
||||
if "tool_calls" in response and response["tool_calls"]:
|
||||
supported = True
|
||||
tool_calls = response["tool_calls"]
|
||||
logger.info(f"✅ 检测到工具调用: {len(tool_calls)} 个")
|
||||
|
||||
# 记录返回的内容(如果有)
|
||||
if "content" in response:
|
||||
response_content = response["content"]
|
||||
elif isinstance(response, str):
|
||||
# 如果只返回字符串,说明不支持工具调用
|
||||
response_content = response
|
||||
|
||||
logger.info(f" - 响应时间: {response_time}ms")
|
||||
logger.info(f" - finish_reason: {finish_reason}")
|
||||
logger.info(f" - 支持状态: {'✅ 支持' if supported else '❌ 不支持'}")
|
||||
|
||||
# 构建详细的返回信息
|
||||
result = {
|
||||
"success": True,
|
||||
"supported": supported,
|
||||
"message": "✅ 模型支持 Function Calling" if supported else "❌ 模型不支持 Function Calling",
|
||||
"response_time_ms": response_time,
|
||||
"provider": provider,
|
||||
"model": llm_model,
|
||||
"details": {
|
||||
"finish_reason": finish_reason,
|
||||
"has_tool_calls": bool(tool_calls),
|
||||
"tool_call_count": len(tool_calls) if tool_calls else 0,
|
||||
"test_tool": "get_weather",
|
||||
"test_prompt": test_prompt,
|
||||
"response_type": "tool_calls" if supported else "text"
|
||||
}
|
||||
}
|
||||
|
||||
# 添加工具调用详情
|
||||
if tool_calls:
|
||||
result["tool_calls"] = tool_calls
|
||||
result["suggestions"] = [
|
||||
"✅ 该模型支持 Function Calling,可以正常使用 MCP 插件",
|
||||
"建议:启用需要的 MCP 插件以扩展 AI 能力",
|
||||
"提示:测试成功检测到工具调用,模型能够正确解析和使用外部工具"
|
||||
]
|
||||
else:
|
||||
result["response_preview"] = response_content[:200] if response_content else None
|
||||
result["suggestions"] = [
|
||||
"❌ 该模型不支持 Function Calling,无法使用 MCP 插件功能",
|
||||
"建议:更换支持工具调用的模型",
|
||||
"推荐模型:GPT-4 系列、GPT-4-turbo、Claude 3 Opus/Sonnet、Gemini 1.5 Pro 等",
|
||||
"说明:模型返回了文本回复而非工具调用,表明不支持该功能"
|
||||
]
|
||||
|
||||
return result
|
||||
|
||||
except ValueError as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f"❌ Function Calling 检测配置错误: {error_msg}")
|
||||
return {
|
||||
"success": False,
|
||||
"supported": False,
|
||||
"message": "配置错误",
|
||||
"error": error_msg,
|
||||
"error_type": "ConfigurationError",
|
||||
"suggestions": [
|
||||
"请检查 API Key 是否正确",
|
||||
"请确认 API Base URL 格式是否正确",
|
||||
"请验证所选提供商与配置是否匹配"
|
||||
]
|
||||
}
|
||||
|
||||
except TimeoutError as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f"❌ Function Calling 检测超时: {error_msg}")
|
||||
return {
|
||||
"success": False,
|
||||
"supported": None,
|
||||
"message": "检测超时",
|
||||
"error": error_msg,
|
||||
"error_type": "TimeoutError",
|
||||
"suggestions": [
|
||||
"请检查网络连接是否正常",
|
||||
"请确认 API 服务是否可访问",
|
||||
"建议:稍后重试或使用其他网络环境"
|
||||
]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
error_type = type(e).__name__
|
||||
|
||||
logger.error(f"❌ Function Calling 检测失败: {error_msg}")
|
||||
logger.error(f" - 错误类型: {error_type}")
|
||||
|
||||
# 智能分析错误原因
|
||||
suggestions = []
|
||||
if "tool" in error_msg.lower() or "function" in error_msg.lower():
|
||||
suggestions = [
|
||||
"该模型可能不支持 Function Calling 功能",
|
||||
"API 返回了与工具调用相关的错误",
|
||||
"建议:更换支持工具调用的模型或联系 API 提供商"
|
||||
]
|
||||
elif "unauthorized" in error_msg.lower() or "401" in error_msg:
|
||||
suggestions = [
|
||||
"API Key 认证失败",
|
||||
"请检查 API Key 是否正确且有效",
|
||||
"请确认 API Key 是否有足够的权限"
|
||||
]
|
||||
elif "not found" in error_msg.lower() or "404" in error_msg:
|
||||
suggestions = [
|
||||
"模型不存在或不可用",
|
||||
"请检查模型名称是否正确",
|
||||
"请确认该模型在当前 API 中是否可用"
|
||||
]
|
||||
else:
|
||||
suggestions = [
|
||||
"检测过程中遇到未知错误",
|
||||
"建议:检查所有配置参数是否正确",
|
||||
"提示:查看详细错误信息以获取更多线索"
|
||||
]
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"supported": False,
|
||||
"message": "Function Calling 检测失败",
|
||||
"error": error_msg,
|
||||
"error_type": error_type,
|
||||
"suggestions": suggestions
|
||||
}
|
||||
|
||||
|
||||
@router.post("/test")
|
||||
async def test_api_connection(data: ApiTestRequest):
|
||||
"""
|
||||
@@ -370,7 +615,8 @@ async def test_api_connection(data: ApiTestRequest):
|
||||
provider=provider,
|
||||
model=llm_model,
|
||||
temperature=0.7,
|
||||
max_tokens=8000
|
||||
max_tokens=8000,
|
||||
auto_mcp=False # 测试时不加载MCP工具
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
+361
-425
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user