feat: 优化MCP工具调用体验并集成通用适配器
- 静默检查MCP工具可用性,支持提示词注入调用mcp - 集成UniversalMCPAdapter,支持自动API能力检测和智能降级 - 新增MCP适配器配置项,增强系统兼容性和健壮性
This commit is contained in:
@@ -68,11 +68,19 @@ async def world_building_generator(
|
||||
reference_materials = ""
|
||||
if enable_mcp and user_id:
|
||||
try:
|
||||
yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集参考资料...", 18)
|
||||
# 先静默检查是否有可用工具
|
||||
from app.services.mcp_tool_service import mcp_tool_service
|
||||
available_tools = await mcp_tool_service.get_user_enabled_tools(
|
||||
user_id=user_id,
|
||||
db_session=db
|
||||
)
|
||||
|
||||
# 直接调用MCP增强的AI,内部会自动检查和加载工具
|
||||
# 构建资料收集提示词
|
||||
planning_prompt = f"""你正在为小说《{title}》设计世界观。
|
||||
# 只有在真正有可用工具时才显示消息和调用
|
||||
if available_tools:
|
||||
yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集参考资料...", 18)
|
||||
|
||||
# 构建资料收集提示词
|
||||
planning_prompt = f"""你正在为小说《{title}》设计世界观。
|
||||
|
||||
【小说信息】
|
||||
- 题材:{genre}
|
||||
@@ -88,28 +96,32 @@ async def world_building_generator(
|
||||
4. 类似作品的设定参考
|
||||
|
||||
请根据题材特点,有针对性地查询2-3个关键问题。"""
|
||||
|
||||
# 调用MCP增强的AI(非流式,最多2轮工具调用)
|
||||
planning_result = await user_ai_service.generate_text_with_mcp(
|
||||
prompt=planning_prompt,
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=2,
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
)
|
||||
|
||||
# 提取参考资料
|
||||
if planning_result.get("tool_calls_made", 0) > 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"✅ MCP工具调用成功({planning_result['tool_calls_made']}次)",
|
||||
25
|
||||
|
||||
# 调用MCP增强的AI(非流式,最多2轮工具调用)
|
||||
planning_result = await user_ai_service.generate_text_with_mcp(
|
||||
prompt=planning_prompt,
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=2,
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
)
|
||||
reference_materials = planning_result.get("content", "")
|
||||
|
||||
# 提取参考资料
|
||||
if planning_result.get("tool_calls_made", 0) > 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"✅ MCP工具调用成功({planning_result['tool_calls_made']}次)",
|
||||
25
|
||||
)
|
||||
reference_materials = planning_result.get("content", "")
|
||||
else:
|
||||
# 有工具但未使用
|
||||
logger.debug("MCP工具可用但AI未选择使用")
|
||||
else:
|
||||
yield await SSEResponse.send_progress("ℹ️ 未使用MCP工具(无可用工具或不需要)", 25)
|
||||
# 没有可用工具,静默跳过
|
||||
logger.debug(f"用户 {user_id} 未启用MCP工具,跳过MCP增强")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"MCP工具调用失败(降级处理): {e}")
|
||||
@@ -325,10 +337,19 @@ async def characters_generator(
|
||||
character_reference_materials = ""
|
||||
if enable_mcp and user_id:
|
||||
try:
|
||||
yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集角色参考资料...", 8)
|
||||
# 先静默检查是否有可用工具
|
||||
from app.services.mcp_tool_service import mcp_tool_service
|
||||
available_tools = await mcp_tool_service.get_user_enabled_tools(
|
||||
user_id=user_id,
|
||||
db_session=db
|
||||
)
|
||||
|
||||
# 构建角色资料收集提示词
|
||||
planning_prompt = f"""你正在为小说《{project.title}》设计角色。
|
||||
# 只有在真正有可用工具时才显示消息和调用
|
||||
if available_tools:
|
||||
yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集角色参考资料...", 8)
|
||||
|
||||
# 构建角色资料收集提示词
|
||||
planning_prompt = f"""你正在为小说《{project.title}》设计角色。
|
||||
|
||||
【小说信息】
|
||||
- 题材:{genre or project.genre}
|
||||
@@ -345,28 +366,32 @@ async def characters_generator(
|
||||
4. 相关领域的人物原型
|
||||
|
||||
请根据题材特点,有针对性地查询1-2个关键问题。"""
|
||||
|
||||
# 调用MCP增强的AI(非流式,最多2轮工具调用)
|
||||
planning_result = await user_ai_service.generate_text_with_mcp(
|
||||
prompt=planning_prompt,
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=2,
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
)
|
||||
|
||||
# 提取参考资料
|
||||
if planning_result.get("tool_calls_made", 0) > 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"✅ MCP工具调用成功({planning_result['tool_calls_made']}次)",
|
||||
12
|
||||
|
||||
# 调用MCP增强的AI(非流式,最多2轮工具调用)
|
||||
planning_result = await user_ai_service.generate_text_with_mcp(
|
||||
prompt=planning_prompt,
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=2,
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
)
|
||||
character_reference_materials = planning_result.get("content", "")
|
||||
|
||||
# 提取参考资料
|
||||
if planning_result.get("tool_calls_made", 0) > 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"✅ MCP工具调用成功({planning_result['tool_calls_made']}次)",
|
||||
12
|
||||
)
|
||||
character_reference_materials = planning_result.get("content", "")
|
||||
else:
|
||||
# 有工具但未使用
|
||||
logger.debug("MCP工具可用但AI未选择使用")
|
||||
else:
|
||||
yield await SSEResponse.send_progress("ℹ️ 未使用MCP工具(无可用工具或不需要)", 12)
|
||||
# 没有可用工具,静默跳过
|
||||
logger.debug(f"用户 {user_id} 未启用MCP工具,跳过MCP增强")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"MCP工具调用失败(降级处理): {e}")
|
||||
|
||||
@@ -77,6 +77,11 @@ class Settings(BaseSettings):
|
||||
default_temperature: float = 0.7
|
||||
default_max_tokens: int = 2000
|
||||
|
||||
# MCP适配器配置
|
||||
enable_mcp_adapter: bool = True # 是否启用MCP适配器(自动检测API能力)
|
||||
mcp_adapter_cache_ttl_hours: int = 24 # API能力检测缓存时长(小时)
|
||||
mcp_adapter_auto_fallback: bool = True # 是否启用自动降级(FC失败时切换到提示词注入)
|
||||
|
||||
# LinuxDO OAuth2 配置
|
||||
LINUXDO_CLIENT_ID: Optional[str] = None
|
||||
LINUXDO_CLIENT_SECRET: Optional[str] = None
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
"""MCP适配器模块 - 支持多种AI API的工具调用方式"""
|
||||
|
||||
from .base import BaseMCPAdapter, AdapterType
|
||||
from .prompt_injection import PromptInjectionAdapter
|
||||
from .function_calling import FunctionCallingAdapter
|
||||
from .universal import UniversalMCPAdapter
|
||||
|
||||
__all__ = [
|
||||
"BaseMCPAdapter",
|
||||
"AdapterType",
|
||||
"PromptInjectionAdapter",
|
||||
"FunctionCallingAdapter",
|
||||
"UniversalMCPAdapter",
|
||||
]
|
||||
@@ -0,0 +1,89 @@
|
||||
"""MCP适配器基类"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class AdapterType(Enum):
|
||||
"""适配器类型"""
|
||||
FUNCTION_CALLING = "function_calling" # 标准Function Calling
|
||||
PROMPT_INJECTION = "prompt_injection" # 提示词注入
|
||||
REACT = "react" # ReAct模式
|
||||
XML = "xml" # XML标记
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCallResult:
|
||||
"""工具调用结果"""
|
||||
tool_calls: List[Dict[str, Any]] # 解析出的工具调用
|
||||
raw_response: str # 原始AI响应
|
||||
has_tool_calls: bool # 是否包含工具调用
|
||||
needs_continuation: bool = False # 是否需要继续对话
|
||||
|
||||
|
||||
class BaseMCPAdapter(ABC):
|
||||
"""MCP适配器基类"""
|
||||
|
||||
def __init__(self):
|
||||
self.adapter_type: AdapterType = AdapterType.PROMPT_INJECTION
|
||||
|
||||
@abstractmethod
|
||||
def format_tools_for_prompt(
|
||||
self,
|
||||
tools: List[Dict[str, Any]],
|
||||
user_message: str
|
||||
) -> str:
|
||||
"""
|
||||
将工具列表格式化为提示词
|
||||
|
||||
Args:
|
||||
tools: MCP工具列表
|
||||
user_message: 用户消息
|
||||
|
||||
Returns:
|
||||
格式化后的提示词
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def parse_tool_calls(self, ai_response: str) -> ToolCallResult:
|
||||
"""
|
||||
从AI响应中解析工具调用
|
||||
|
||||
Args:
|
||||
ai_response: AI的原始响应
|
||||
|
||||
Returns:
|
||||
解析结果
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def build_continuation_prompt(
|
||||
self,
|
||||
original_message: str,
|
||||
ai_response: str,
|
||||
tool_results: List[Dict[str, Any]]
|
||||
) -> str:
|
||||
"""
|
||||
构建包含工具结果的继续对话提示词
|
||||
|
||||
Args:
|
||||
original_message: 原始用户消息
|
||||
ai_response: AI响应
|
||||
tool_results: 工具执行结果
|
||||
|
||||
Returns:
|
||||
继续对话的提示词
|
||||
"""
|
||||
pass
|
||||
|
||||
def supports_native_tools(self) -> bool:
|
||||
"""是否支持原生工具调用(如Function Calling)"""
|
||||
return False
|
||||
|
||||
def get_adapter_type(self) -> AdapterType:
|
||||
"""获取适配器类型"""
|
||||
return self.adapter_type
|
||||
@@ -0,0 +1,171 @@
|
||||
"""Function Calling适配器 - 支持原生Function Calling的API"""
|
||||
|
||||
import json
|
||||
from typing import Dict, Any, List
|
||||
from app.mcp.adapters.base import BaseMCPAdapter, AdapterType, ToolCallResult
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class FunctionCallingAdapter(BaseMCPAdapter):
|
||||
"""Function Calling适配器 - 用于支持原生工具调用的AI API(如OpenAI)"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.adapter_type = AdapterType.FUNCTION_CALLING
|
||||
|
||||
def supports_native_tools(self) -> bool:
|
||||
"""支持原生工具调用"""
|
||||
return True
|
||||
|
||||
def format_tools_for_prompt(
|
||||
self,
|
||||
tools: List[Dict[str, Any]],
|
||||
user_message: str
|
||||
) -> str:
|
||||
"""
|
||||
Function Calling模式下,工具通过API参数传递,不需要修改提示词
|
||||
|
||||
Returns:
|
||||
原始用户消息
|
||||
"""
|
||||
return user_message
|
||||
|
||||
def get_tools_for_api(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取适用于API的工具格式
|
||||
|
||||
Args:
|
||||
tools: MCP工具列表
|
||||
|
||||
Returns:
|
||||
适用于OpenAI Function Calling的工具格式
|
||||
"""
|
||||
return tools
|
||||
|
||||
def parse_tool_calls(self, ai_response: Any) -> ToolCallResult:
|
||||
"""
|
||||
从AI响应中解析工具调用(Function Calling格式)
|
||||
|
||||
Args:
|
||||
ai_response: AI响应对象(通常是OpenAI的ChatCompletion对象)
|
||||
|
||||
Returns:
|
||||
解析结果
|
||||
"""
|
||||
|
||||
try:
|
||||
# 处理不同类型的响应
|
||||
if isinstance(ai_response, dict):
|
||||
# 字典格式(OpenAI API响应)
|
||||
message = ai_response.get("choices", [{}])[0].get("message", {})
|
||||
tool_calls = message.get("tool_calls", [])
|
||||
content = message.get("content", "")
|
||||
|
||||
elif hasattr(ai_response, "choices"):
|
||||
# 对象格式(OpenAI SDK响应)
|
||||
message = ai_response.choices[0].message
|
||||
tool_calls = getattr(message, "tool_calls", None) or []
|
||||
content = getattr(message, "content", "") or ""
|
||||
|
||||
# 转换为字典格式
|
||||
if tool_calls:
|
||||
tool_calls = [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": tc.type,
|
||||
"function": {
|
||||
"name": tc.function.name,
|
||||
"arguments": tc.function.arguments
|
||||
}
|
||||
}
|
||||
for tc in tool_calls
|
||||
]
|
||||
else:
|
||||
# 字符串格式(降级为文本响应)
|
||||
return ToolCallResult(
|
||||
tool_calls=[],
|
||||
raw_response=str(ai_response),
|
||||
has_tool_calls=False
|
||||
)
|
||||
|
||||
has_tool_calls = len(tool_calls) > 0
|
||||
|
||||
if has_tool_calls:
|
||||
logger.info(f"✅ Function Calling模式解析出 {len(tool_calls)} 个工具调用")
|
||||
for tc in tool_calls:
|
||||
logger.info(f" - {tc['function']['name']}")
|
||||
|
||||
return ToolCallResult(
|
||||
tool_calls=tool_calls,
|
||||
raw_response=content or "",
|
||||
has_tool_calls=has_tool_calls,
|
||||
needs_continuation=has_tool_calls
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 解析Function Calling响应失败: {e}", exc_info=True)
|
||||
return ToolCallResult(
|
||||
tool_calls=[],
|
||||
raw_response=str(ai_response),
|
||||
has_tool_calls=False
|
||||
)
|
||||
|
||||
def build_continuation_prompt(
|
||||
self,
|
||||
original_message: str,
|
||||
ai_response: str,
|
||||
tool_results: List[Dict[str, Any]]
|
||||
) -> str:
|
||||
"""
|
||||
构建包含工具结果的继续对话提示词
|
||||
|
||||
在Function Calling模式下,这通常不需要,因为工具结果会作为消息历史的一部分
|
||||
"""
|
||||
# Function Calling模式下通常通过消息历史传递工具结果
|
||||
# 这里提供一个降级方案
|
||||
results_text = "\n\n".join([
|
||||
f"工具 {r['name']} 的结果:\n{r['content']}"
|
||||
for r in tool_results
|
||||
])
|
||||
|
||||
return f"{original_message}\n\n工具执行结果:\n{results_text}\n\n请基于以上工具结果回答用户的问题。"
|
||||
|
||||
def build_messages_with_tool_results(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
tool_calls: List[Dict[str, Any]],
|
||||
tool_results: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
构建包含工具结果的消息历史(Function Calling标准格式)
|
||||
|
||||
Args:
|
||||
messages: 原始消息历史
|
||||
tool_calls: AI的工具调用
|
||||
tool_results: 工具执行结果
|
||||
|
||||
Returns:
|
||||
更新后的消息历史
|
||||
"""
|
||||
|
||||
new_messages = messages.copy()
|
||||
|
||||
# 添加助手的工具调用消息
|
||||
new_messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": tool_calls
|
||||
})
|
||||
|
||||
# 添加工具结果消息
|
||||
for result in tool_results:
|
||||
new_messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": result.get("tool_call_id", ""),
|
||||
"name": result.get("name", ""),
|
||||
"content": result.get("content", "")
|
||||
})
|
||||
|
||||
return new_messages
|
||||
@@ -0,0 +1,274 @@
|
||||
"""提示词注入适配器 - 最通用的MCP工具调用方式"""
|
||||
|
||||
import re
|
||||
import json
|
||||
from typing import Dict, Any, List
|
||||
from app.mcp.adapters.base import BaseMCPAdapter, AdapterType, ToolCallResult
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PromptInjectionAdapter(BaseMCPAdapter):
|
||||
"""提示词注入适配器 - 将工具转换为文本描述,通过提示词引导AI调用"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.adapter_type = AdapterType.PROMPT_INJECTION
|
||||
|
||||
def format_tools_for_prompt(
|
||||
self,
|
||||
tools: List[Dict[str, Any]],
|
||||
user_message: str
|
||||
) -> str:
|
||||
"""将工具列表注入到提示词中"""
|
||||
|
||||
if not tools:
|
||||
return user_message
|
||||
|
||||
# 格式化工具描述
|
||||
tool_descriptions = self._format_tools_as_text(tools)
|
||||
|
||||
# 构建增强的提示词
|
||||
enhanced_prompt = f"""你现在可以使用以下工具来帮助回答用户的问题。
|
||||
|
||||
## 可用工具
|
||||
|
||||
{tool_descriptions}
|
||||
|
||||
## 工具使用说明
|
||||
|
||||
当你需要使用工具时,请按以下XML格式输出(可以一次调用多个工具):
|
||||
|
||||
<tool_calls>
|
||||
<tool_call>
|
||||
<tool_name>工具名称</tool_name>
|
||||
<arguments>
|
||||
{{
|
||||
"参数名1": "参数值1",
|
||||
"参数名2": "参数值2"
|
||||
}}
|
||||
</arguments>
|
||||
</tool_call>
|
||||
</tool_calls>
|
||||
|
||||
## 重要提示
|
||||
|
||||
1. 只有在确实需要使用工具时才调用工具
|
||||
2. 参数必须是有效的JSON格式
|
||||
3. 仔细检查参数是否符合工具的要求
|
||||
4. 可以在一个<tool_calls>标签内包含多个<tool_call>
|
||||
5. 调用工具后,你会收到工具的执行结果,然后需要基于结果继续回答
|
||||
|
||||
---
|
||||
|
||||
用户问题:{user_message}
|
||||
|
||||
请分析问题,判断是否需要使用工具。如果需要,先输出工具调用,然后等待结果。如果不需要,直接回答问题。"""
|
||||
|
||||
return enhanced_prompt
|
||||
|
||||
def _format_tools_as_text(self, tools: List[Dict[str, Any]]) -> str:
|
||||
"""将工具格式化为可读的文本描述"""
|
||||
lines = []
|
||||
|
||||
for i, tool in enumerate(tools, 1):
|
||||
func = tool.get("function", {})
|
||||
name = func.get("name", "unknown")
|
||||
description = func.get("description", "无描述")
|
||||
parameters = func.get("parameters", {})
|
||||
|
||||
lines.append(f"### {i}. {name}")
|
||||
lines.append(f"**描述**: {description}")
|
||||
lines.append("")
|
||||
|
||||
# 格式化参数信息
|
||||
if parameters and "properties" in parameters:
|
||||
lines.append("**参数**:")
|
||||
properties = parameters.get("properties", {})
|
||||
required = parameters.get("required", [])
|
||||
|
||||
for param_name, param_info in properties.items():
|
||||
param_type = param_info.get("type", "string")
|
||||
param_desc = param_info.get("description", "")
|
||||
is_required = "必填" if param_name in required else "可选"
|
||||
|
||||
lines.append(f" - `{param_name}` ({param_type}, {is_required}): {param_desc}")
|
||||
lines.append("")
|
||||
|
||||
# 添加示例
|
||||
if "example" in func:
|
||||
lines.append(f"**示例**: {json.dumps(func['example'], ensure_ascii=False)}")
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def parse_tool_calls(self, ai_response) -> ToolCallResult:
|
||||
"""从AI响应中解析工具调用"""
|
||||
|
||||
tool_calls = []
|
||||
|
||||
try:
|
||||
# 处理不同类型的响应
|
||||
if isinstance(ai_response, dict):
|
||||
# 如果是字典,提取content字段
|
||||
ai_response = ai_response.get("choices", [{}])[0].get("message", {}).get("content", "")
|
||||
if not ai_response:
|
||||
return ToolCallResult(
|
||||
tool_calls=[],
|
||||
raw_response="",
|
||||
has_tool_calls=False
|
||||
)
|
||||
elif not isinstance(ai_response, str):
|
||||
# 转换为字符串
|
||||
ai_response = str(ai_response)
|
||||
|
||||
# 使用正则提取 <tool_calls> 标签内容
|
||||
tool_calls_match = re.search(
|
||||
r'<tool_calls>(.*?)</tool_calls>',
|
||||
ai_response,
|
||||
re.DOTALL | re.IGNORECASE
|
||||
)
|
||||
|
||||
if not tool_calls_match:
|
||||
# 没有找到工具调用
|
||||
return ToolCallResult(
|
||||
tool_calls=[],
|
||||
raw_response=ai_response,
|
||||
has_tool_calls=False
|
||||
)
|
||||
|
||||
tool_calls_content = tool_calls_match.group(1)
|
||||
|
||||
# 提取所有 <tool_call> 标签
|
||||
tool_call_pattern = r'<tool_call>(.*?)</tool_call>'
|
||||
tool_call_matches = re.findall(
|
||||
tool_call_pattern,
|
||||
tool_calls_content,
|
||||
re.DOTALL | re.IGNORECASE
|
||||
)
|
||||
|
||||
for i, tool_call_content in enumerate(tool_call_matches):
|
||||
# 提取工具名称
|
||||
name_match = re.search(
|
||||
r'<tool_name>(.*?)</tool_name>',
|
||||
tool_call_content,
|
||||
re.DOTALL | re.IGNORECASE
|
||||
)
|
||||
|
||||
# 提取参数
|
||||
args_match = re.search(
|
||||
r'<arguments>(.*?)</arguments>',
|
||||
tool_call_content,
|
||||
re.DOTALL | re.IGNORECASE
|
||||
)
|
||||
|
||||
if name_match and args_match:
|
||||
tool_name = name_match.group(1).strip()
|
||||
arguments_str = args_match.group(1).strip()
|
||||
|
||||
try:
|
||||
# 解析JSON参数
|
||||
arguments = json.loads(arguments_str)
|
||||
|
||||
# 构建标准格式的工具调用
|
||||
tool_calls.append({
|
||||
"id": f"call_{i}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_name,
|
||||
"arguments": json.dumps(arguments, ensure_ascii=False)
|
||||
}
|
||||
})
|
||||
|
||||
logger.info(f"✅ 解析工具调用: {tool_name}")
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"❌ 解析工具参数失败: {arguments_str}, 错误: {e}")
|
||||
continue
|
||||
|
||||
has_tool_calls = len(tool_calls) > 0
|
||||
|
||||
if has_tool_calls:
|
||||
logger.info(f"✅ 从响应中解析出 {len(tool_calls)} 个工具调用")
|
||||
|
||||
return ToolCallResult(
|
||||
tool_calls=tool_calls,
|
||||
raw_response=ai_response,
|
||||
has_tool_calls=has_tool_calls,
|
||||
needs_continuation=has_tool_calls
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 解析工具调用失败: {e}", exc_info=True)
|
||||
return ToolCallResult(
|
||||
tool_calls=[],
|
||||
raw_response=ai_response,
|
||||
has_tool_calls=False
|
||||
)
|
||||
|
||||
def build_continuation_prompt(
|
||||
self,
|
||||
original_message: str,
|
||||
ai_response: str,
|
||||
tool_results: List[Dict[str, Any]]
|
||||
) -> str:
|
||||
"""构建包含工具结果的继续对话提示词"""
|
||||
|
||||
# 格式化工具结果
|
||||
results_text = self._format_tool_results(tool_results)
|
||||
|
||||
continuation = f"""你之前尝试使用工具来回答用户的问题。
|
||||
|
||||
原始问题:{original_message}
|
||||
|
||||
你的工具调用:
|
||||
{self._extract_tool_calls_text(ai_response)}
|
||||
|
||||
工具执行结果:
|
||||
{results_text}
|
||||
|
||||
现在,请基于这些工具的执行结果,给出完整、详细的回答。不要重复调用工具,直接使用已有的结果来回答用户的问题。"""
|
||||
|
||||
return continuation
|
||||
|
||||
def _format_tool_results(self, tool_results: List[Dict[str, Any]]) -> str:
|
||||
"""格式化工具结果为可读文本"""
|
||||
lines = []
|
||||
|
||||
for i, result in enumerate(tool_results, 1):
|
||||
tool_name = result.get("name", "unknown")
|
||||
success = result.get("success", False)
|
||||
content = result.get("content", "")
|
||||
|
||||
status = "✅ 成功" if success else "❌ 失败"
|
||||
lines.append(f"{i}. {tool_name} - {status}")
|
||||
|
||||
if success:
|
||||
# 尝试美化JSON内容
|
||||
try:
|
||||
if isinstance(content, str):
|
||||
content_obj = json.loads(content)
|
||||
content = json.dumps(content_obj, ensure_ascii=False, indent=2)
|
||||
except:
|
||||
pass
|
||||
lines.append(f"```\n{content}\n```")
|
||||
else:
|
||||
error = result.get("error", "未知错误")
|
||||
lines.append(f"错误信息: {error}")
|
||||
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _extract_tool_calls_text(self, ai_response: str) -> str:
|
||||
"""从AI响应中提取工具调用部分的文本"""
|
||||
match = re.search(
|
||||
r'<tool_calls>(.*?)</tool_calls>',
|
||||
ai_response,
|
||||
re.DOTALL | re.IGNORECASE
|
||||
)
|
||||
|
||||
if match:
|
||||
return match.group(0)
|
||||
return "(未找到工具调用)"
|
||||
@@ -0,0 +1,353 @@
|
||||
"""通用MCP适配器 - 自动检测API能力并选择最佳适配器"""
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass
|
||||
|
||||
from app.mcp.adapters.base import BaseMCPAdapter, AdapterType, ToolCallResult
|
||||
from app.mcp.adapters.prompt_injection import PromptInjectionAdapter
|
||||
from app.mcp.adapters.function_calling import FunctionCallingAdapter
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class APICapability:
|
||||
"""API能力检测结果"""
|
||||
supports_function_calling: bool
|
||||
tested_at: datetime
|
||||
test_duration_ms: float
|
||||
error_message: Optional[str] = None
|
||||
|
||||
|
||||
class UniversalMCPAdapter:
|
||||
"""
|
||||
通用MCP适配器管理器
|
||||
|
||||
功能:
|
||||
1. 自动检测API是否支持Function Calling
|
||||
2. 缓存检测结果
|
||||
3. 自动降级策略:FC失败时切换到提示词注入
|
||||
4. 提供统一接口
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cache_ttl_hours: int = 24,
|
||||
enable_auto_fallback: bool = True
|
||||
):
|
||||
"""
|
||||
初始化通用适配器
|
||||
|
||||
Args:
|
||||
cache_ttl_hours: 能力检测缓存时长(小时)
|
||||
enable_auto_fallback: 是否启用自动降级
|
||||
"""
|
||||
# 适配器实例
|
||||
self.adapters = {
|
||||
AdapterType.FUNCTION_CALLING: FunctionCallingAdapter(),
|
||||
AdapterType.PROMPT_INJECTION: PromptInjectionAdapter()
|
||||
}
|
||||
|
||||
# API能力缓存: {api_identifier: APICapability}
|
||||
self._capability_cache: Dict[str, APICapability] = {}
|
||||
self._cache_ttl = timedelta(hours=cache_ttl_hours)
|
||||
self._cache_lock = asyncio.Lock()
|
||||
|
||||
# 配置
|
||||
self._enable_auto_fallback = enable_auto_fallback
|
||||
|
||||
logger.info(
|
||||
f"✅ UniversalMCPAdapter初始化完成 "
|
||||
f"(缓存TTL={cache_ttl_hours}小时, 自动降级={'开启' if enable_auto_fallback else '关闭'})"
|
||||
)
|
||||
|
||||
async def get_adapter(
|
||||
self,
|
||||
api_identifier: str,
|
||||
test_function: Optional[callable] = None
|
||||
) -> BaseMCPAdapter:
|
||||
"""
|
||||
获取适合当前API的适配器
|
||||
|
||||
Args:
|
||||
api_identifier: API标识符(如"openai_official", "azure_openai"等)
|
||||
test_function: 可选的测试函数,用于检测API能力
|
||||
|
||||
Returns:
|
||||
最适合的适配器实例
|
||||
"""
|
||||
|
||||
# 检查缓存
|
||||
capability = await self._get_cached_capability(api_identifier)
|
||||
|
||||
if capability is None and test_function:
|
||||
# 缓存未命中,执行检测
|
||||
capability = await self._detect_capability(api_identifier, test_function)
|
||||
|
||||
# 选择适配器
|
||||
if capability and capability.supports_function_calling:
|
||||
logger.info(f"🎯 使用Function Calling适配器: {api_identifier}")
|
||||
return self.adapters[AdapterType.FUNCTION_CALLING]
|
||||
else:
|
||||
logger.info(f"🎯 使用提示词注入适配器: {api_identifier}")
|
||||
return self.adapters[AdapterType.PROMPT_INJECTION]
|
||||
|
||||
async def _get_cached_capability(
|
||||
self,
|
||||
api_identifier: str
|
||||
) -> Optional[APICapability]:
|
||||
"""获取缓存的能力检测结果"""
|
||||
|
||||
async with self._cache_lock:
|
||||
if api_identifier not in self._capability_cache:
|
||||
return None
|
||||
|
||||
capability = self._capability_cache[api_identifier]
|
||||
|
||||
# 检查是否过期
|
||||
if datetime.now() - capability.tested_at > self._cache_ttl:
|
||||
logger.info(f"⏰ API能力缓存过期: {api_identifier}")
|
||||
del self._capability_cache[api_identifier]
|
||||
return None
|
||||
|
||||
logger.debug(f"🎯 API能力缓存命中: {api_identifier}")
|
||||
return capability
|
||||
|
||||
async def _detect_capability(
|
||||
self,
|
||||
api_identifier: str,
|
||||
test_function: callable
|
||||
) -> APICapability:
|
||||
"""
|
||||
检测API能力
|
||||
|
||||
Args:
|
||||
api_identifier: API标识符
|
||||
test_function: 测试函数,应该尝试使用Function Calling
|
||||
|
||||
Returns:
|
||||
能力检测结果
|
||||
"""
|
||||
|
||||
logger.info(f"🔍 开始检测API能力: {api_identifier}")
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 调用测试函数
|
||||
result = await test_function()
|
||||
|
||||
# 判断是否成功
|
||||
supports_fc = self._is_function_calling_response(result)
|
||||
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
|
||||
capability = APICapability(
|
||||
supports_function_calling=supports_fc,
|
||||
tested_at=datetime.now(),
|
||||
test_duration_ms=duration_ms
|
||||
)
|
||||
|
||||
# 缓存结果
|
||||
async with self._cache_lock:
|
||||
self._capability_cache[api_identifier] = capability
|
||||
|
||||
status = "✅ 支持" if supports_fc else "❌ 不支持"
|
||||
logger.info(
|
||||
f"{status} Function Calling: {api_identifier} "
|
||||
f"(耗时: {duration_ms:.2f}ms)"
|
||||
)
|
||||
|
||||
return capability
|
||||
|
||||
except Exception as e:
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
|
||||
logger.warning(
|
||||
f"⚠️ API能力检测失败: {api_identifier}, 错误: {e}, "
|
||||
f"将使用提示词注入模式"
|
||||
)
|
||||
|
||||
capability = APICapability(
|
||||
supports_function_calling=False,
|
||||
tested_at=datetime.now(),
|
||||
test_duration_ms=duration_ms,
|
||||
error_message=str(e)
|
||||
)
|
||||
|
||||
# 缓存失败结果(避免重复测试)
|
||||
async with self._cache_lock:
|
||||
self._capability_cache[api_identifier] = capability
|
||||
|
||||
return capability
|
||||
|
||||
def _is_function_calling_response(self, response: Any) -> bool:
|
||||
"""
|
||||
判断响应是否是Function Calling格式
|
||||
|
||||
Args:
|
||||
response: API响应
|
||||
|
||||
Returns:
|
||||
是否支持Function Calling
|
||||
"""
|
||||
|
||||
try:
|
||||
# 检查字典格式
|
||||
if isinstance(response, dict):
|
||||
message = response.get("choices", [{}])[0].get("message", {})
|
||||
return "tool_calls" in message or "function_call" in message
|
||||
|
||||
# 检查对象格式(OpenAI SDK)
|
||||
if hasattr(response, "choices"):
|
||||
message = response.choices[0].message
|
||||
return hasattr(message, "tool_calls") or hasattr(message, "function_call")
|
||||
|
||||
return False
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def call_with_fallback(
|
||||
self,
|
||||
api_identifier: str,
|
||||
tools: List[Dict[str, Any]],
|
||||
user_message: str,
|
||||
call_function: callable,
|
||||
test_function: Optional[callable] = None
|
||||
) -> ToolCallResult:
|
||||
"""
|
||||
带降级策略的工具调用
|
||||
|
||||
Args:
|
||||
api_identifier: API标识符
|
||||
tools: MCP工具列表
|
||||
user_message: 用户消息
|
||||
call_function: 实际调用API的函数
|
||||
test_function: 可选的测试函数
|
||||
|
||||
Returns:
|
||||
工具调用结果
|
||||
"""
|
||||
|
||||
# 获取适配器
|
||||
adapter = await self.get_adapter(api_identifier, test_function)
|
||||
|
||||
# 首次尝试
|
||||
try:
|
||||
if adapter.supports_native_tools():
|
||||
# Function Calling模式
|
||||
logger.info("🚀 尝试使用Function Calling模式")
|
||||
result = await self._try_function_calling(
|
||||
tools, user_message, call_function, adapter
|
||||
)
|
||||
else:
|
||||
# 提示词注入模式
|
||||
logger.info("🚀 使用提示词注入模式")
|
||||
result = await self._try_prompt_injection(
|
||||
tools, user_message, call_function, adapter
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 工具调用失败: {e}")
|
||||
|
||||
# 自动降级
|
||||
if self._enable_auto_fallback and adapter.supports_native_tools():
|
||||
logger.warning("⚠️ Function Calling失败,降级到提示词注入模式")
|
||||
|
||||
# 更新缓存,标记为不支持
|
||||
async with self._cache_lock:
|
||||
self._capability_cache[api_identifier] = APICapability(
|
||||
supports_function_calling=False,
|
||||
tested_at=datetime.now(),
|
||||
test_duration_ms=0,
|
||||
error_message=str(e)
|
||||
)
|
||||
|
||||
# 使用提示词注入重试
|
||||
fallback_adapter = self.adapters[AdapterType.PROMPT_INJECTION]
|
||||
return await self._try_prompt_injection(
|
||||
tools, user_message, call_function, fallback_adapter
|
||||
)
|
||||
|
||||
raise
|
||||
|
||||
async def _try_function_calling(
|
||||
self,
|
||||
tools: List[Dict[str, Any]],
|
||||
user_message: str,
|
||||
call_function: callable,
|
||||
adapter: FunctionCallingAdapter
|
||||
) -> ToolCallResult:
|
||||
"""尝试Function Calling模式"""
|
||||
|
||||
# Function Calling不需要修改提示词
|
||||
response = await call_function(
|
||||
message=user_message,
|
||||
tools_param=tools,
|
||||
tool_choice_param="auto"
|
||||
)
|
||||
|
||||
return adapter.parse_tool_calls(response)
|
||||
|
||||
async def _try_prompt_injection(
|
||||
self,
|
||||
tools: List[Dict[str, Any]],
|
||||
user_message: str,
|
||||
call_function: callable,
|
||||
adapter: PromptInjectionAdapter
|
||||
) -> ToolCallResult:
|
||||
"""尝试提示词注入模式"""
|
||||
|
||||
# 注入工具到提示词
|
||||
enhanced_prompt = adapter.format_tools_for_prompt(tools, user_message)
|
||||
|
||||
# 调用API(不传tools参数)
|
||||
response = await call_function(
|
||||
message=enhanced_prompt,
|
||||
tools_param=None,
|
||||
tool_choice_param=None
|
||||
)
|
||||
|
||||
# 从文本响应中解析工具调用
|
||||
return adapter.parse_tool_calls(response)
|
||||
|
||||
def clear_cache(self, api_identifier: Optional[str] = None):
|
||||
"""
|
||||
清理能力缓存
|
||||
|
||||
Args:
|
||||
api_identifier: 可选,只清理特定API的缓存
|
||||
"""
|
||||
if api_identifier:
|
||||
if api_identifier in self._capability_cache:
|
||||
del self._capability_cache[api_identifier]
|
||||
logger.info(f"🧹 已清理API能力缓存: {api_identifier}")
|
||||
else:
|
||||
self._capability_cache.clear()
|
||||
logger.info("🧹 已清理所有API能力缓存")
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
"""获取缓存统计信息"""
|
||||
return {
|
||||
"total_cached": len(self._capability_cache),
|
||||
"cache_ttl_hours": self._cache_ttl.total_seconds() / 3600,
|
||||
"cached_apis": [
|
||||
{
|
||||
"api_identifier": api_id,
|
||||
"supports_fc": cap.supports_function_calling,
|
||||
"tested_at": cap.tested_at.isoformat(),
|
||||
"test_duration_ms": cap.test_duration_ms
|
||||
}
|
||||
for api_id, cap in self._capability_cache.items()
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
# 全局单例
|
||||
universal_mcp_adapter = UniversalMCPAdapter()
|
||||
@@ -4,6 +4,7 @@ from openai import AsyncOpenAI
|
||||
from anthropic import AsyncAnthropic
|
||||
from app.config import settings as app_settings
|
||||
from app.logger import get_logger
|
||||
from app.mcp.adapters import UniversalMCPAdapter, PromptInjectionAdapter
|
||||
import httpx
|
||||
import json
|
||||
import hashlib
|
||||
@@ -118,7 +119,8 @@ class AIService:
|
||||
api_base_url: Optional[str] = None,
|
||||
default_model: Optional[str] = None,
|
||||
default_temperature: Optional[float] = None,
|
||||
default_max_tokens: Optional[int] = None
|
||||
default_max_tokens: Optional[int] = None,
|
||||
enable_mcp_adapter: bool = True
|
||||
):
|
||||
"""
|
||||
初始化AI客户端(优化并发性能)
|
||||
@@ -137,6 +139,15 @@ class AIService:
|
||||
self.default_temperature = default_temperature or app_settings.default_temperature
|
||||
self.default_max_tokens = default_max_tokens or app_settings.default_max_tokens
|
||||
|
||||
# 初始化MCP适配器
|
||||
self.enable_mcp_adapter = enable_mcp_adapter
|
||||
if enable_mcp_adapter:
|
||||
self.mcp_adapter = UniversalMCPAdapter()
|
||||
logger.info("✅ MCP通用适配器已启用")
|
||||
else:
|
||||
self.mcp_adapter = None
|
||||
logger.info("⚠️ MCP适配器已禁用")
|
||||
|
||||
# 初始化OpenAI客户端(使用HTTP客户端池)
|
||||
openai_key = api_key if api_provider == "openai" else app_settings.openai_api_key
|
||||
if openai_key:
|
||||
@@ -396,7 +407,7 @@ class AIService:
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
tool_choice: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""使用OpenAI生成文本(支持工具调用)"""
|
||||
"""使用OpenAI生成文本(支持工具调用,集成MCP适配器)"""
|
||||
if not self.openai_http_client:
|
||||
raise ValueError("OpenAI客户端未初始化,请检查API key配置")
|
||||
|
||||
@@ -405,8 +416,101 @@ class AIService:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
# 如果启用了MCP适配器且有工具,使用适配器处理
|
||||
if self.enable_mcp_adapter and self.mcp_adapter and tools:
|
||||
logger.info(f"🎯 使用MCP适配器处理工具调用")
|
||||
|
||||
# 生成API标识符
|
||||
api_identifier = f"openai_{self.openai_base_url or 'default'}"
|
||||
|
||||
# 定义API调用函数
|
||||
async def call_api(message: str, tools_param: Optional[List] = None, tool_choice_param: Optional[str] = None):
|
||||
"""实际调用OpenAI API的函数"""
|
||||
call_messages = messages.copy()
|
||||
call_messages[-1]["content"] = message
|
||||
|
||||
url = f"{self.openai_base_url}/chat/completions"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.openai_api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": call_messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens
|
||||
}
|
||||
|
||||
# 只在tools_param不为None时添加工具参数
|
||||
if tools_param is not None:
|
||||
# 清理工具定义,移除$schema字段(某些API不支持)
|
||||
cleaned_tools = []
|
||||
for tool in tools_param:
|
||||
cleaned_tool = tool.copy()
|
||||
if "function" in cleaned_tool and "parameters" in cleaned_tool["function"]:
|
||||
params = cleaned_tool["function"]["parameters"].copy()
|
||||
# 移除$schema字段
|
||||
params.pop("$schema", None)
|
||||
cleaned_tool["function"]["parameters"] = params
|
||||
cleaned_tools.append(cleaned_tool)
|
||||
|
||||
payload["tools"] = cleaned_tools
|
||||
if tool_choice_param:
|
||||
payload["tool_choice"] = tool_choice_param
|
||||
|
||||
response = await self.openai_http_client.post(url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
# 定义测试函数(检测API是否支持Function Calling)
|
||||
async def test_fc():
|
||||
"""测试Function Calling支持"""
|
||||
test_tools = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_function",
|
||||
"description": "测试函数",
|
||||
"parameters": {"type": "object", "properties": {}}
|
||||
}
|
||||
}]
|
||||
try:
|
||||
result = await call_api("测试", tools_param=test_tools, tool_choice_param="none")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.debug(f"Function Calling测试失败: {e}")
|
||||
raise
|
||||
|
||||
try:
|
||||
# 使用适配器处理(自动检测、降级、缓存)
|
||||
result = await self.mcp_adapter.call_with_fallback(
|
||||
api_identifier=api_identifier,
|
||||
tools=tools,
|
||||
user_message=prompt,
|
||||
call_function=call_api,
|
||||
test_function=test_fc
|
||||
)
|
||||
|
||||
# 转换结果格式
|
||||
if result.has_tool_calls:
|
||||
return {
|
||||
"tool_calls": result.tool_calls,
|
||||
"content": result.raw_response,
|
||||
"finish_reason": "tool_calls"
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"content": result.raw_response,
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ MCP适配器调用失败: {str(e)}")
|
||||
# 降级到原始实现
|
||||
logger.warning("⚠️ 降级到原始OpenAI调用")
|
||||
|
||||
# 原始实现(无适配器或降级)
|
||||
try:
|
||||
logger.info(f"🔵 开始调用OpenAI API(支持工具调用)")
|
||||
logger.info(f"🔵 开始调用OpenAI API(原始模式)")
|
||||
logger.info(f" - 模型: {model}")
|
||||
logger.info(f" - 工具数量: {len(tools) if tools else 0}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user