feat: 优化MCP工具调用体验并集成通用适配器
- 静默检查MCP工具可用性,支持提示词注入调用mcp - 集成UniversalMCPAdapter,支持自动API能力检测和智能降级 - 新增MCP适配器配置项,增强系统兼容性和健壮性
This commit is contained in:
@@ -68,9 +68,17 @@ async def world_building_generator(
|
|||||||
reference_materials = ""
|
reference_materials = ""
|
||||||
if enable_mcp and user_id:
|
if enable_mcp and user_id:
|
||||||
try:
|
try:
|
||||||
|
# 先静默检查是否有可用工具
|
||||||
|
from app.services.mcp_tool_service import mcp_tool_service
|
||||||
|
available_tools = await mcp_tool_service.get_user_enabled_tools(
|
||||||
|
user_id=user_id,
|
||||||
|
db_session=db
|
||||||
|
)
|
||||||
|
|
||||||
|
# 只有在真正有可用工具时才显示消息和调用
|
||||||
|
if available_tools:
|
||||||
yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集参考资料...", 18)
|
yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集参考资料...", 18)
|
||||||
|
|
||||||
# 直接调用MCP增强的AI,内部会自动检查和加载工具
|
|
||||||
# 构建资料收集提示词
|
# 构建资料收集提示词
|
||||||
planning_prompt = f"""你正在为小说《{title}》设计世界观。
|
planning_prompt = f"""你正在为小说《{title}》设计世界观。
|
||||||
|
|
||||||
@@ -109,7 +117,11 @@ async def world_building_generator(
|
|||||||
)
|
)
|
||||||
reference_materials = planning_result.get("content", "")
|
reference_materials = planning_result.get("content", "")
|
||||||
else:
|
else:
|
||||||
yield await SSEResponse.send_progress("ℹ️ 未使用MCP工具(无可用工具或不需要)", 25)
|
# 有工具但未使用
|
||||||
|
logger.debug("MCP工具可用但AI未选择使用")
|
||||||
|
else:
|
||||||
|
# 没有可用工具,静默跳过
|
||||||
|
logger.debug(f"用户 {user_id} 未启用MCP工具,跳过MCP增强")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"MCP工具调用失败(降级处理): {e}")
|
logger.warning(f"MCP工具调用失败(降级处理): {e}")
|
||||||
@@ -325,6 +337,15 @@ async def characters_generator(
|
|||||||
character_reference_materials = ""
|
character_reference_materials = ""
|
||||||
if enable_mcp and user_id:
|
if enable_mcp and user_id:
|
||||||
try:
|
try:
|
||||||
|
# 先静默检查是否有可用工具
|
||||||
|
from app.services.mcp_tool_service import mcp_tool_service
|
||||||
|
available_tools = await mcp_tool_service.get_user_enabled_tools(
|
||||||
|
user_id=user_id,
|
||||||
|
db_session=db
|
||||||
|
)
|
||||||
|
|
||||||
|
# 只有在真正有可用工具时才显示消息和调用
|
||||||
|
if available_tools:
|
||||||
yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集角色参考资料...", 8)
|
yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集角色参考资料...", 8)
|
||||||
|
|
||||||
# 构建角色资料收集提示词
|
# 构建角色资料收集提示词
|
||||||
@@ -366,7 +387,11 @@ async def characters_generator(
|
|||||||
)
|
)
|
||||||
character_reference_materials = planning_result.get("content", "")
|
character_reference_materials = planning_result.get("content", "")
|
||||||
else:
|
else:
|
||||||
yield await SSEResponse.send_progress("ℹ️ 未使用MCP工具(无可用工具或不需要)", 12)
|
# 有工具但未使用
|
||||||
|
logger.debug("MCP工具可用但AI未选择使用")
|
||||||
|
else:
|
||||||
|
# 没有可用工具,静默跳过
|
||||||
|
logger.debug(f"用户 {user_id} 未启用MCP工具,跳过MCP增强")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"MCP工具调用失败(降级处理): {e}")
|
logger.warning(f"MCP工具调用失败(降级处理): {e}")
|
||||||
|
|||||||
@@ -77,6 +77,11 @@ class Settings(BaseSettings):
|
|||||||
default_temperature: float = 0.7
|
default_temperature: float = 0.7
|
||||||
default_max_tokens: int = 2000
|
default_max_tokens: int = 2000
|
||||||
|
|
||||||
|
# MCP适配器配置
|
||||||
|
enable_mcp_adapter: bool = True # 是否启用MCP适配器(自动检测API能力)
|
||||||
|
mcp_adapter_cache_ttl_hours: int = 24 # API能力检测缓存时长(小时)
|
||||||
|
mcp_adapter_auto_fallback: bool = True # 是否启用自动降级(FC失败时切换到提示词注入)
|
||||||
|
|
||||||
# LinuxDO OAuth2 配置
|
# LinuxDO OAuth2 配置
|
||||||
LINUXDO_CLIENT_ID: Optional[str] = None
|
LINUXDO_CLIENT_ID: Optional[str] = None
|
||||||
LINUXDO_CLIENT_SECRET: Optional[str] = None
|
LINUXDO_CLIENT_SECRET: Optional[str] = None
|
||||||
|
|||||||
@@ -0,0 +1,14 @@
|
|||||||
|
"""MCP适配器模块 - 支持多种AI API的工具调用方式"""
|
||||||
|
|
||||||
|
from .base import BaseMCPAdapter, AdapterType
|
||||||
|
from .prompt_injection import PromptInjectionAdapter
|
||||||
|
from .function_calling import FunctionCallingAdapter
|
||||||
|
from .universal import UniversalMCPAdapter
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseMCPAdapter",
|
||||||
|
"AdapterType",
|
||||||
|
"PromptInjectionAdapter",
|
||||||
|
"FunctionCallingAdapter",
|
||||||
|
"UniversalMCPAdapter",
|
||||||
|
]
|
||||||
@@ -0,0 +1,89 @@
|
|||||||
|
"""MCP适配器基类"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Dict, Any, List, Optional
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
class AdapterType(Enum):
|
||||||
|
"""适配器类型"""
|
||||||
|
FUNCTION_CALLING = "function_calling" # 标准Function Calling
|
||||||
|
PROMPT_INJECTION = "prompt_injection" # 提示词注入
|
||||||
|
REACT = "react" # ReAct模式
|
||||||
|
XML = "xml" # XML标记
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolCallResult:
|
||||||
|
"""工具调用结果"""
|
||||||
|
tool_calls: List[Dict[str, Any]] # 解析出的工具调用
|
||||||
|
raw_response: str # 原始AI响应
|
||||||
|
has_tool_calls: bool # 是否包含工具调用
|
||||||
|
needs_continuation: bool = False # 是否需要继续对话
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMCPAdapter(ABC):
|
||||||
|
"""MCP适配器基类"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.adapter_type: AdapterType = AdapterType.PROMPT_INJECTION
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_tools_for_prompt(
|
||||||
|
self,
|
||||||
|
tools: List[Dict[str, Any]],
|
||||||
|
user_message: str
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
将工具列表格式化为提示词
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tools: MCP工具列表
|
||||||
|
user_message: 用户消息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
格式化后的提示词
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def parse_tool_calls(self, ai_response: str) -> ToolCallResult:
|
||||||
|
"""
|
||||||
|
从AI响应中解析工具调用
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ai_response: AI的原始响应
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
解析结果
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def build_continuation_prompt(
|
||||||
|
self,
|
||||||
|
original_message: str,
|
||||||
|
ai_response: str,
|
||||||
|
tool_results: List[Dict[str, Any]]
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
构建包含工具结果的继续对话提示词
|
||||||
|
|
||||||
|
Args:
|
||||||
|
original_message: 原始用户消息
|
||||||
|
ai_response: AI响应
|
||||||
|
tool_results: 工具执行结果
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
继续对话的提示词
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def supports_native_tools(self) -> bool:
|
||||||
|
"""是否支持原生工具调用(如Function Calling)"""
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_adapter_type(self) -> AdapterType:
|
||||||
|
"""获取适配器类型"""
|
||||||
|
return self.adapter_type
|
||||||
@@ -0,0 +1,171 @@
|
|||||||
|
"""Function Calling适配器 - 支持原生Function Calling的API"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Dict, Any, List
|
||||||
|
from app.mcp.adapters.base import BaseMCPAdapter, AdapterType, ToolCallResult
|
||||||
|
from app.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionCallingAdapter(BaseMCPAdapter):
|
||||||
|
"""Function Calling适配器 - 用于支持原生工具调用的AI API(如OpenAI)"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.adapter_type = AdapterType.FUNCTION_CALLING
|
||||||
|
|
||||||
|
def supports_native_tools(self) -> bool:
|
||||||
|
"""支持原生工具调用"""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def format_tools_for_prompt(
|
||||||
|
self,
|
||||||
|
tools: List[Dict[str, Any]],
|
||||||
|
user_message: str
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Function Calling模式下,工具通过API参数传递,不需要修改提示词
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
原始用户消息
|
||||||
|
"""
|
||||||
|
return user_message
|
||||||
|
|
||||||
|
def get_tools_for_api(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
获取适用于API的工具格式
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tools: MCP工具列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
适用于OpenAI Function Calling的工具格式
|
||||||
|
"""
|
||||||
|
return tools
|
||||||
|
|
||||||
|
def parse_tool_calls(self, ai_response: Any) -> ToolCallResult:
|
||||||
|
"""
|
||||||
|
从AI响应中解析工具调用(Function Calling格式)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ai_response: AI响应对象(通常是OpenAI的ChatCompletion对象)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
解析结果
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 处理不同类型的响应
|
||||||
|
if isinstance(ai_response, dict):
|
||||||
|
# 字典格式(OpenAI API响应)
|
||||||
|
message = ai_response.get("choices", [{}])[0].get("message", {})
|
||||||
|
tool_calls = message.get("tool_calls", [])
|
||||||
|
content = message.get("content", "")
|
||||||
|
|
||||||
|
elif hasattr(ai_response, "choices"):
|
||||||
|
# 对象格式(OpenAI SDK响应)
|
||||||
|
message = ai_response.choices[0].message
|
||||||
|
tool_calls = getattr(message, "tool_calls", None) or []
|
||||||
|
content = getattr(message, "content", "") or ""
|
||||||
|
|
||||||
|
# 转换为字典格式
|
||||||
|
if tool_calls:
|
||||||
|
tool_calls = [
|
||||||
|
{
|
||||||
|
"id": tc.id,
|
||||||
|
"type": tc.type,
|
||||||
|
"function": {
|
||||||
|
"name": tc.function.name,
|
||||||
|
"arguments": tc.function.arguments
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for tc in tool_calls
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
# 字符串格式(降级为文本响应)
|
||||||
|
return ToolCallResult(
|
||||||
|
tool_calls=[],
|
||||||
|
raw_response=str(ai_response),
|
||||||
|
has_tool_calls=False
|
||||||
|
)
|
||||||
|
|
||||||
|
has_tool_calls = len(tool_calls) > 0
|
||||||
|
|
||||||
|
if has_tool_calls:
|
||||||
|
logger.info(f"✅ Function Calling模式解析出 {len(tool_calls)} 个工具调用")
|
||||||
|
for tc in tool_calls:
|
||||||
|
logger.info(f" - {tc['function']['name']}")
|
||||||
|
|
||||||
|
return ToolCallResult(
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
raw_response=content or "",
|
||||||
|
has_tool_calls=has_tool_calls,
|
||||||
|
needs_continuation=has_tool_calls
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 解析Function Calling响应失败: {e}", exc_info=True)
|
||||||
|
return ToolCallResult(
|
||||||
|
tool_calls=[],
|
||||||
|
raw_response=str(ai_response),
|
||||||
|
has_tool_calls=False
|
||||||
|
)
|
||||||
|
|
||||||
|
def build_continuation_prompt(
|
||||||
|
self,
|
||||||
|
original_message: str,
|
||||||
|
ai_response: str,
|
||||||
|
tool_results: List[Dict[str, Any]]
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
构建包含工具结果的继续对话提示词
|
||||||
|
|
||||||
|
在Function Calling模式下,这通常不需要,因为工具结果会作为消息历史的一部分
|
||||||
|
"""
|
||||||
|
# Function Calling模式下通常通过消息历史传递工具结果
|
||||||
|
# 这里提供一个降级方案
|
||||||
|
results_text = "\n\n".join([
|
||||||
|
f"工具 {r['name']} 的结果:\n{r['content']}"
|
||||||
|
for r in tool_results
|
||||||
|
])
|
||||||
|
|
||||||
|
return f"{original_message}\n\n工具执行结果:\n{results_text}\n\n请基于以上工具结果回答用户的问题。"
|
||||||
|
|
||||||
|
def build_messages_with_tool_results(
|
||||||
|
self,
|
||||||
|
messages: List[Dict[str, Any]],
|
||||||
|
tool_calls: List[Dict[str, Any]],
|
||||||
|
tool_results: List[Dict[str, Any]]
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
构建包含工具结果的消息历史(Function Calling标准格式)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: 原始消息历史
|
||||||
|
tool_calls: AI的工具调用
|
||||||
|
tool_results: 工具执行结果
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
更新后的消息历史
|
||||||
|
"""
|
||||||
|
|
||||||
|
new_messages = messages.copy()
|
||||||
|
|
||||||
|
# 添加助手的工具调用消息
|
||||||
|
new_messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": None,
|
||||||
|
"tool_calls": tool_calls
|
||||||
|
})
|
||||||
|
|
||||||
|
# 添加工具结果消息
|
||||||
|
for result in tool_results:
|
||||||
|
new_messages.append({
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": result.get("tool_call_id", ""),
|
||||||
|
"name": result.get("name", ""),
|
||||||
|
"content": result.get("content", "")
|
||||||
|
})
|
||||||
|
|
||||||
|
return new_messages
|
||||||
@@ -0,0 +1,274 @@
|
|||||||
|
"""提示词注入适配器 - 最通用的MCP工具调用方式"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
from typing import Dict, Any, List
|
||||||
|
from app.mcp.adapters.base import BaseMCPAdapter, AdapterType, ToolCallResult
|
||||||
|
from app.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PromptInjectionAdapter(BaseMCPAdapter):
|
||||||
|
"""提示词注入适配器 - 将工具转换为文本描述,通过提示词引导AI调用"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.adapter_type = AdapterType.PROMPT_INJECTION
|
||||||
|
|
||||||
|
def format_tools_for_prompt(
|
||||||
|
self,
|
||||||
|
tools: List[Dict[str, Any]],
|
||||||
|
user_message: str
|
||||||
|
) -> str:
|
||||||
|
"""将工具列表注入到提示词中"""
|
||||||
|
|
||||||
|
if not tools:
|
||||||
|
return user_message
|
||||||
|
|
||||||
|
# 格式化工具描述
|
||||||
|
tool_descriptions = self._format_tools_as_text(tools)
|
||||||
|
|
||||||
|
# 构建增强的提示词
|
||||||
|
enhanced_prompt = f"""你现在可以使用以下工具来帮助回答用户的问题。
|
||||||
|
|
||||||
|
## 可用工具
|
||||||
|
|
||||||
|
{tool_descriptions}
|
||||||
|
|
||||||
|
## 工具使用说明
|
||||||
|
|
||||||
|
当你需要使用工具时,请按以下XML格式输出(可以一次调用多个工具):
|
||||||
|
|
||||||
|
<tool_calls>
|
||||||
|
<tool_call>
|
||||||
|
<tool_name>工具名称</tool_name>
|
||||||
|
<arguments>
|
||||||
|
{{
|
||||||
|
"参数名1": "参数值1",
|
||||||
|
"参数名2": "参数值2"
|
||||||
|
}}
|
||||||
|
</arguments>
|
||||||
|
</tool_call>
|
||||||
|
</tool_calls>
|
||||||
|
|
||||||
|
## 重要提示
|
||||||
|
|
||||||
|
1. 只有在确实需要使用工具时才调用工具
|
||||||
|
2. 参数必须是有效的JSON格式
|
||||||
|
3. 仔细检查参数是否符合工具的要求
|
||||||
|
4. 可以在一个<tool_calls>标签内包含多个<tool_call>
|
||||||
|
5. 调用工具后,你会收到工具的执行结果,然后需要基于结果继续回答
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
用户问题:{user_message}
|
||||||
|
|
||||||
|
请分析问题,判断是否需要使用工具。如果需要,先输出工具调用,然后等待结果。如果不需要,直接回答问题。"""
|
||||||
|
|
||||||
|
return enhanced_prompt
|
||||||
|
|
||||||
|
def _format_tools_as_text(self, tools: List[Dict[str, Any]]) -> str:
|
||||||
|
"""将工具格式化为可读的文本描述"""
|
||||||
|
lines = []
|
||||||
|
|
||||||
|
for i, tool in enumerate(tools, 1):
|
||||||
|
func = tool.get("function", {})
|
||||||
|
name = func.get("name", "unknown")
|
||||||
|
description = func.get("description", "无描述")
|
||||||
|
parameters = func.get("parameters", {})
|
||||||
|
|
||||||
|
lines.append(f"### {i}. {name}")
|
||||||
|
lines.append(f"**描述**: {description}")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
# 格式化参数信息
|
||||||
|
if parameters and "properties" in parameters:
|
||||||
|
lines.append("**参数**:")
|
||||||
|
properties = parameters.get("properties", {})
|
||||||
|
required = parameters.get("required", [])
|
||||||
|
|
||||||
|
for param_name, param_info in properties.items():
|
||||||
|
param_type = param_info.get("type", "string")
|
||||||
|
param_desc = param_info.get("description", "")
|
||||||
|
is_required = "必填" if param_name in required else "可选"
|
||||||
|
|
||||||
|
lines.append(f" - `{param_name}` ({param_type}, {is_required}): {param_desc}")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
# 添加示例
|
||||||
|
if "example" in func:
|
||||||
|
lines.append(f"**示例**: {json.dumps(func['example'], ensure_ascii=False)}")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
def parse_tool_calls(self, ai_response) -> ToolCallResult:
|
||||||
|
"""从AI响应中解析工具调用"""
|
||||||
|
|
||||||
|
tool_calls = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 处理不同类型的响应
|
||||||
|
if isinstance(ai_response, dict):
|
||||||
|
# 如果是字典,提取content字段
|
||||||
|
ai_response = ai_response.get("choices", [{}])[0].get("message", {}).get("content", "")
|
||||||
|
if not ai_response:
|
||||||
|
return ToolCallResult(
|
||||||
|
tool_calls=[],
|
||||||
|
raw_response="",
|
||||||
|
has_tool_calls=False
|
||||||
|
)
|
||||||
|
elif not isinstance(ai_response, str):
|
||||||
|
# 转换为字符串
|
||||||
|
ai_response = str(ai_response)
|
||||||
|
|
||||||
|
# 使用正则提取 <tool_calls> 标签内容
|
||||||
|
tool_calls_match = re.search(
|
||||||
|
r'<tool_calls>(.*?)</tool_calls>',
|
||||||
|
ai_response,
|
||||||
|
re.DOTALL | re.IGNORECASE
|
||||||
|
)
|
||||||
|
|
||||||
|
if not tool_calls_match:
|
||||||
|
# 没有找到工具调用
|
||||||
|
return ToolCallResult(
|
||||||
|
tool_calls=[],
|
||||||
|
raw_response=ai_response,
|
||||||
|
has_tool_calls=False
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_calls_content = tool_calls_match.group(1)
|
||||||
|
|
||||||
|
# 提取所有 <tool_call> 标签
|
||||||
|
tool_call_pattern = r'<tool_call>(.*?)</tool_call>'
|
||||||
|
tool_call_matches = re.findall(
|
||||||
|
tool_call_pattern,
|
||||||
|
tool_calls_content,
|
||||||
|
re.DOTALL | re.IGNORECASE
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, tool_call_content in enumerate(tool_call_matches):
|
||||||
|
# 提取工具名称
|
||||||
|
name_match = re.search(
|
||||||
|
r'<tool_name>(.*?)</tool_name>',
|
||||||
|
tool_call_content,
|
||||||
|
re.DOTALL | re.IGNORECASE
|
||||||
|
)
|
||||||
|
|
||||||
|
# 提取参数
|
||||||
|
args_match = re.search(
|
||||||
|
r'<arguments>(.*?)</arguments>',
|
||||||
|
tool_call_content,
|
||||||
|
re.DOTALL | re.IGNORECASE
|
||||||
|
)
|
||||||
|
|
||||||
|
if name_match and args_match:
|
||||||
|
tool_name = name_match.group(1).strip()
|
||||||
|
arguments_str = args_match.group(1).strip()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 解析JSON参数
|
||||||
|
arguments = json.loads(arguments_str)
|
||||||
|
|
||||||
|
# 构建标准格式的工具调用
|
||||||
|
tool_calls.append({
|
||||||
|
"id": f"call_{i}",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": tool_name,
|
||||||
|
"arguments": json.dumps(arguments, ensure_ascii=False)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.info(f"✅ 解析工具调用: {tool_name}")
|
||||||
|
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"❌ 解析工具参数失败: {arguments_str}, 错误: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
has_tool_calls = len(tool_calls) > 0
|
||||||
|
|
||||||
|
if has_tool_calls:
|
||||||
|
logger.info(f"✅ 从响应中解析出 {len(tool_calls)} 个工具调用")
|
||||||
|
|
||||||
|
return ToolCallResult(
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
raw_response=ai_response,
|
||||||
|
has_tool_calls=has_tool_calls,
|
||||||
|
needs_continuation=has_tool_calls
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 解析工具调用失败: {e}", exc_info=True)
|
||||||
|
return ToolCallResult(
|
||||||
|
tool_calls=[],
|
||||||
|
raw_response=ai_response,
|
||||||
|
has_tool_calls=False
|
||||||
|
)
|
||||||
|
|
||||||
|
def build_continuation_prompt(
|
||||||
|
self,
|
||||||
|
original_message: str,
|
||||||
|
ai_response: str,
|
||||||
|
tool_results: List[Dict[str, Any]]
|
||||||
|
) -> str:
|
||||||
|
"""构建包含工具结果的继续对话提示词"""
|
||||||
|
|
||||||
|
# 格式化工具结果
|
||||||
|
results_text = self._format_tool_results(tool_results)
|
||||||
|
|
||||||
|
continuation = f"""你之前尝试使用工具来回答用户的问题。
|
||||||
|
|
||||||
|
原始问题:{original_message}
|
||||||
|
|
||||||
|
你的工具调用:
|
||||||
|
{self._extract_tool_calls_text(ai_response)}
|
||||||
|
|
||||||
|
工具执行结果:
|
||||||
|
{results_text}
|
||||||
|
|
||||||
|
现在,请基于这些工具的执行结果,给出完整、详细的回答。不要重复调用工具,直接使用已有的结果来回答用户的问题。"""
|
||||||
|
|
||||||
|
return continuation
|
||||||
|
|
||||||
|
def _format_tool_results(self, tool_results: List[Dict[str, Any]]) -> str:
|
||||||
|
"""格式化工具结果为可读文本"""
|
||||||
|
lines = []
|
||||||
|
|
||||||
|
for i, result in enumerate(tool_results, 1):
|
||||||
|
tool_name = result.get("name", "unknown")
|
||||||
|
success = result.get("success", False)
|
||||||
|
content = result.get("content", "")
|
||||||
|
|
||||||
|
status = "✅ 成功" if success else "❌ 失败"
|
||||||
|
lines.append(f"{i}. {tool_name} - {status}")
|
||||||
|
|
||||||
|
if success:
|
||||||
|
# 尝试美化JSON内容
|
||||||
|
try:
|
||||||
|
if isinstance(content, str):
|
||||||
|
content_obj = json.loads(content)
|
||||||
|
content = json.dumps(content_obj, ensure_ascii=False, indent=2)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
lines.append(f"```\n{content}\n```")
|
||||||
|
else:
|
||||||
|
error = result.get("error", "未知错误")
|
||||||
|
lines.append(f"错误信息: {error}")
|
||||||
|
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
def _extract_tool_calls_text(self, ai_response: str) -> str:
|
||||||
|
"""从AI响应中提取工具调用部分的文本"""
|
||||||
|
match = re.search(
|
||||||
|
r'<tool_calls>(.*?)</tool_calls>',
|
||||||
|
ai_response,
|
||||||
|
re.DOTALL | re.IGNORECASE
|
||||||
|
)
|
||||||
|
|
||||||
|
if match:
|
||||||
|
return match.group(0)
|
||||||
|
return "(未找到工具调用)"
|
||||||
@@ -0,0 +1,353 @@
|
|||||||
|
"""通用MCP适配器 - 自动检测API能力并选择最佳适配器"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
from typing import Dict, Any, List, Optional
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from app.mcp.adapters.base import BaseMCPAdapter, AdapterType, ToolCallResult
|
||||||
|
from app.mcp.adapters.prompt_injection import PromptInjectionAdapter
|
||||||
|
from app.mcp.adapters.function_calling import FunctionCallingAdapter
|
||||||
|
from app.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class APICapability:
|
||||||
|
"""API能力检测结果"""
|
||||||
|
supports_function_calling: bool
|
||||||
|
tested_at: datetime
|
||||||
|
test_duration_ms: float
|
||||||
|
error_message: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class UniversalMCPAdapter:
|
||||||
|
"""
|
||||||
|
通用MCP适配器管理器
|
||||||
|
|
||||||
|
功能:
|
||||||
|
1. 自动检测API是否支持Function Calling
|
||||||
|
2. 缓存检测结果
|
||||||
|
3. 自动降级策略:FC失败时切换到提示词注入
|
||||||
|
4. 提供统一接口
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cache_ttl_hours: int = 24,
|
||||||
|
enable_auto_fallback: bool = True
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化通用适配器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cache_ttl_hours: 能力检测缓存时长(小时)
|
||||||
|
enable_auto_fallback: 是否启用自动降级
|
||||||
|
"""
|
||||||
|
# 适配器实例
|
||||||
|
self.adapters = {
|
||||||
|
AdapterType.FUNCTION_CALLING: FunctionCallingAdapter(),
|
||||||
|
AdapterType.PROMPT_INJECTION: PromptInjectionAdapter()
|
||||||
|
}
|
||||||
|
|
||||||
|
# API能力缓存: {api_identifier: APICapability}
|
||||||
|
self._capability_cache: Dict[str, APICapability] = {}
|
||||||
|
self._cache_ttl = timedelta(hours=cache_ttl_hours)
|
||||||
|
self._cache_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
# 配置
|
||||||
|
self._enable_auto_fallback = enable_auto_fallback
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"✅ UniversalMCPAdapter初始化完成 "
|
||||||
|
f"(缓存TTL={cache_ttl_hours}小时, 自动降级={'开启' if enable_auto_fallback else '关闭'})"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_adapter(
|
||||||
|
self,
|
||||||
|
api_identifier: str,
|
||||||
|
test_function: Optional[callable] = None
|
||||||
|
) -> BaseMCPAdapter:
|
||||||
|
"""
|
||||||
|
获取适合当前API的适配器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_identifier: API标识符(如"openai_official", "azure_openai"等)
|
||||||
|
test_function: 可选的测试函数,用于检测API能力
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
最适合的适配器实例
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 检查缓存
|
||||||
|
capability = await self._get_cached_capability(api_identifier)
|
||||||
|
|
||||||
|
if capability is None and test_function:
|
||||||
|
# 缓存未命中,执行检测
|
||||||
|
capability = await self._detect_capability(api_identifier, test_function)
|
||||||
|
|
||||||
|
# 选择适配器
|
||||||
|
if capability and capability.supports_function_calling:
|
||||||
|
logger.info(f"🎯 使用Function Calling适配器: {api_identifier}")
|
||||||
|
return self.adapters[AdapterType.FUNCTION_CALLING]
|
||||||
|
else:
|
||||||
|
logger.info(f"🎯 使用提示词注入适配器: {api_identifier}")
|
||||||
|
return self.adapters[AdapterType.PROMPT_INJECTION]
|
||||||
|
|
||||||
|
async def _get_cached_capability(
|
||||||
|
self,
|
||||||
|
api_identifier: str
|
||||||
|
) -> Optional[APICapability]:
|
||||||
|
"""获取缓存的能力检测结果"""
|
||||||
|
|
||||||
|
async with self._cache_lock:
|
||||||
|
if api_identifier not in self._capability_cache:
|
||||||
|
return None
|
||||||
|
|
||||||
|
capability = self._capability_cache[api_identifier]
|
||||||
|
|
||||||
|
# 检查是否过期
|
||||||
|
if datetime.now() - capability.tested_at > self._cache_ttl:
|
||||||
|
logger.info(f"⏰ API能力缓存过期: {api_identifier}")
|
||||||
|
del self._capability_cache[api_identifier]
|
||||||
|
return None
|
||||||
|
|
||||||
|
logger.debug(f"🎯 API能力缓存命中: {api_identifier}")
|
||||||
|
return capability
|
||||||
|
|
||||||
|
async def _detect_capability(
|
||||||
|
self,
|
||||||
|
api_identifier: str,
|
||||||
|
test_function: callable
|
||||||
|
) -> APICapability:
|
||||||
|
"""
|
||||||
|
检测API能力
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_identifier: API标识符
|
||||||
|
test_function: 测试函数,应该尝试使用Function Calling
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
能力检测结果
|
||||||
|
"""
|
||||||
|
|
||||||
|
logger.info(f"🔍 开始检测API能力: {api_identifier}")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 调用测试函数
|
||||||
|
result = await test_function()
|
||||||
|
|
||||||
|
# 判断是否成功
|
||||||
|
supports_fc = self._is_function_calling_response(result)
|
||||||
|
|
||||||
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
|
capability = APICapability(
|
||||||
|
supports_function_calling=supports_fc,
|
||||||
|
tested_at=datetime.now(),
|
||||||
|
test_duration_ms=duration_ms
|
||||||
|
)
|
||||||
|
|
||||||
|
# 缓存结果
|
||||||
|
async with self._cache_lock:
|
||||||
|
self._capability_cache[api_identifier] = capability
|
||||||
|
|
||||||
|
status = "✅ 支持" if supports_fc else "❌ 不支持"
|
||||||
|
logger.info(
|
||||||
|
f"{status} Function Calling: {api_identifier} "
|
||||||
|
f"(耗时: {duration_ms:.2f}ms)"
|
||||||
|
)
|
||||||
|
|
||||||
|
return capability
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
f"⚠️ API能力检测失败: {api_identifier}, 错误: {e}, "
|
||||||
|
f"将使用提示词注入模式"
|
||||||
|
)
|
||||||
|
|
||||||
|
capability = APICapability(
|
||||||
|
supports_function_calling=False,
|
||||||
|
tested_at=datetime.now(),
|
||||||
|
test_duration_ms=duration_ms,
|
||||||
|
error_message=str(e)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 缓存失败结果(避免重复测试)
|
||||||
|
async with self._cache_lock:
|
||||||
|
self._capability_cache[api_identifier] = capability
|
||||||
|
|
||||||
|
return capability
|
||||||
|
|
||||||
|
def _is_function_calling_response(self, response: Any) -> bool:
|
||||||
|
"""
|
||||||
|
判断响应是否是Function Calling格式
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: API响应
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否支持Function Calling
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 检查字典格式
|
||||||
|
if isinstance(response, dict):
|
||||||
|
message = response.get("choices", [{}])[0].get("message", {})
|
||||||
|
return "tool_calls" in message or "function_call" in message
|
||||||
|
|
||||||
|
# 检查对象格式(OpenAI SDK)
|
||||||
|
if hasattr(response, "choices"):
|
||||||
|
message = response.choices[0].message
|
||||||
|
return hasattr(message, "tool_calls") or hasattr(message, "function_call")
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def call_with_fallback(
|
||||||
|
self,
|
||||||
|
api_identifier: str,
|
||||||
|
tools: List[Dict[str, Any]],
|
||||||
|
user_message: str,
|
||||||
|
call_function: callable,
|
||||||
|
test_function: Optional[callable] = None
|
||||||
|
) -> ToolCallResult:
|
||||||
|
"""
|
||||||
|
带降级策略的工具调用
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_identifier: API标识符
|
||||||
|
tools: MCP工具列表
|
||||||
|
user_message: 用户消息
|
||||||
|
call_function: 实际调用API的函数
|
||||||
|
test_function: 可选的测试函数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具调用结果
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 获取适配器
|
||||||
|
adapter = await self.get_adapter(api_identifier, test_function)
|
||||||
|
|
||||||
|
# 首次尝试
|
||||||
|
try:
|
||||||
|
if adapter.supports_native_tools():
|
||||||
|
# Function Calling模式
|
||||||
|
logger.info("🚀 尝试使用Function Calling模式")
|
||||||
|
result = await self._try_function_calling(
|
||||||
|
tools, user_message, call_function, adapter
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 提示词注入模式
|
||||||
|
logger.info("🚀 使用提示词注入模式")
|
||||||
|
result = await self._try_prompt_injection(
|
||||||
|
tools, user_message, call_function, adapter
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 工具调用失败: {e}")
|
||||||
|
|
||||||
|
# 自动降级
|
||||||
|
if self._enable_auto_fallback and adapter.supports_native_tools():
|
||||||
|
logger.warning("⚠️ Function Calling失败,降级到提示词注入模式")
|
||||||
|
|
||||||
|
# 更新缓存,标记为不支持
|
||||||
|
async with self._cache_lock:
|
||||||
|
self._capability_cache[api_identifier] = APICapability(
|
||||||
|
supports_function_calling=False,
|
||||||
|
tested_at=datetime.now(),
|
||||||
|
test_duration_ms=0,
|
||||||
|
error_message=str(e)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 使用提示词注入重试
|
||||||
|
fallback_adapter = self.adapters[AdapterType.PROMPT_INJECTION]
|
||||||
|
return await self._try_prompt_injection(
|
||||||
|
tools, user_message, call_function, fallback_adapter
|
||||||
|
)
|
||||||
|
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _try_function_calling(
|
||||||
|
self,
|
||||||
|
tools: List[Dict[str, Any]],
|
||||||
|
user_message: str,
|
||||||
|
call_function: callable,
|
||||||
|
adapter: FunctionCallingAdapter
|
||||||
|
) -> ToolCallResult:
|
||||||
|
"""尝试Function Calling模式"""
|
||||||
|
|
||||||
|
# Function Calling不需要修改提示词
|
||||||
|
response = await call_function(
|
||||||
|
message=user_message,
|
||||||
|
tools_param=tools,
|
||||||
|
tool_choice_param="auto"
|
||||||
|
)
|
||||||
|
|
||||||
|
return adapter.parse_tool_calls(response)
|
||||||
|
|
||||||
|
async def _try_prompt_injection(
|
||||||
|
self,
|
||||||
|
tools: List[Dict[str, Any]],
|
||||||
|
user_message: str,
|
||||||
|
call_function: callable,
|
||||||
|
adapter: PromptInjectionAdapter
|
||||||
|
) -> ToolCallResult:
|
||||||
|
"""尝试提示词注入模式"""
|
||||||
|
|
||||||
|
# 注入工具到提示词
|
||||||
|
enhanced_prompt = adapter.format_tools_for_prompt(tools, user_message)
|
||||||
|
|
||||||
|
# 调用API(不传tools参数)
|
||||||
|
response = await call_function(
|
||||||
|
message=enhanced_prompt,
|
||||||
|
tools_param=None,
|
||||||
|
tool_choice_param=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# 从文本响应中解析工具调用
|
||||||
|
return adapter.parse_tool_calls(response)
|
||||||
|
|
||||||
|
def clear_cache(self, api_identifier: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
清理能力缓存
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_identifier: 可选,只清理特定API的缓存
|
||||||
|
"""
|
||||||
|
if api_identifier:
|
||||||
|
if api_identifier in self._capability_cache:
|
||||||
|
del self._capability_cache[api_identifier]
|
||||||
|
logger.info(f"🧹 已清理API能力缓存: {api_identifier}")
|
||||||
|
else:
|
||||||
|
self._capability_cache.clear()
|
||||||
|
logger.info("🧹 已清理所有API能力缓存")
|
||||||
|
|
||||||
|
def get_cache_stats(self) -> Dict[str, Any]:
|
||||||
|
"""获取缓存统计信息"""
|
||||||
|
return {
|
||||||
|
"total_cached": len(self._capability_cache),
|
||||||
|
"cache_ttl_hours": self._cache_ttl.total_seconds() / 3600,
|
||||||
|
"cached_apis": [
|
||||||
|
{
|
||||||
|
"api_identifier": api_id,
|
||||||
|
"supports_fc": cap.supports_function_calling,
|
||||||
|
"tested_at": cap.tested_at.isoformat(),
|
||||||
|
"test_duration_ms": cap.test_duration_ms
|
||||||
|
}
|
||||||
|
for api_id, cap in self._capability_cache.items()
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 全局单例
|
||||||
|
universal_mcp_adapter = UniversalMCPAdapter()
|
||||||
@@ -4,6 +4,7 @@ from openai import AsyncOpenAI
|
|||||||
from anthropic import AsyncAnthropic
|
from anthropic import AsyncAnthropic
|
||||||
from app.config import settings as app_settings
|
from app.config import settings as app_settings
|
||||||
from app.logger import get_logger
|
from app.logger import get_logger
|
||||||
|
from app.mcp.adapters import UniversalMCPAdapter, PromptInjectionAdapter
|
||||||
import httpx
|
import httpx
|
||||||
import json
|
import json
|
||||||
import hashlib
|
import hashlib
|
||||||
@@ -118,7 +119,8 @@ class AIService:
|
|||||||
api_base_url: Optional[str] = None,
|
api_base_url: Optional[str] = None,
|
||||||
default_model: Optional[str] = None,
|
default_model: Optional[str] = None,
|
||||||
default_temperature: Optional[float] = None,
|
default_temperature: Optional[float] = None,
|
||||||
default_max_tokens: Optional[int] = None
|
default_max_tokens: Optional[int] = None,
|
||||||
|
enable_mcp_adapter: bool = True
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
初始化AI客户端(优化并发性能)
|
初始化AI客户端(优化并发性能)
|
||||||
@@ -137,6 +139,15 @@ class AIService:
|
|||||||
self.default_temperature = default_temperature or app_settings.default_temperature
|
self.default_temperature = default_temperature or app_settings.default_temperature
|
||||||
self.default_max_tokens = default_max_tokens or app_settings.default_max_tokens
|
self.default_max_tokens = default_max_tokens or app_settings.default_max_tokens
|
||||||
|
|
||||||
|
# 初始化MCP适配器
|
||||||
|
self.enable_mcp_adapter = enable_mcp_adapter
|
||||||
|
if enable_mcp_adapter:
|
||||||
|
self.mcp_adapter = UniversalMCPAdapter()
|
||||||
|
logger.info("✅ MCP通用适配器已启用")
|
||||||
|
else:
|
||||||
|
self.mcp_adapter = None
|
||||||
|
logger.info("⚠️ MCP适配器已禁用")
|
||||||
|
|
||||||
# 初始化OpenAI客户端(使用HTTP客户端池)
|
# 初始化OpenAI客户端(使用HTTP客户端池)
|
||||||
openai_key = api_key if api_provider == "openai" else app_settings.openai_api_key
|
openai_key = api_key if api_provider == "openai" else app_settings.openai_api_key
|
||||||
if openai_key:
|
if openai_key:
|
||||||
@@ -396,7 +407,7 @@ class AIService:
|
|||||||
tools: Optional[List[Dict[str, Any]]] = None,
|
tools: Optional[List[Dict[str, Any]]] = None,
|
||||||
tool_choice: Optional[str] = None
|
tool_choice: Optional[str] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""使用OpenAI生成文本(支持工具调用)"""
|
"""使用OpenAI生成文本(支持工具调用,集成MCP适配器)"""
|
||||||
if not self.openai_http_client:
|
if not self.openai_http_client:
|
||||||
raise ValueError("OpenAI客户端未初始化,请检查API key配置")
|
raise ValueError("OpenAI客户端未初始化,请检查API key配置")
|
||||||
|
|
||||||
@@ -405,8 +416,101 @@ class AIService:
|
|||||||
messages.append({"role": "system", "content": system_prompt})
|
messages.append({"role": "system", "content": system_prompt})
|
||||||
messages.append({"role": "user", "content": prompt})
|
messages.append({"role": "user", "content": prompt})
|
||||||
|
|
||||||
|
# 如果启用了MCP适配器且有工具,使用适配器处理
|
||||||
|
if self.enable_mcp_adapter and self.mcp_adapter and tools:
|
||||||
|
logger.info(f"🎯 使用MCP适配器处理工具调用")
|
||||||
|
|
||||||
|
# 生成API标识符
|
||||||
|
api_identifier = f"openai_{self.openai_base_url or 'default'}"
|
||||||
|
|
||||||
|
# 定义API调用函数
|
||||||
|
async def call_api(message: str, tools_param: Optional[List] = None, tool_choice_param: Optional[str] = None):
|
||||||
|
"""实际调用OpenAI API的函数"""
|
||||||
|
call_messages = messages.copy()
|
||||||
|
call_messages[-1]["content"] = message
|
||||||
|
|
||||||
|
url = f"{self.openai_base_url}/chat/completions"
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.openai_api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
payload = {
|
||||||
|
"model": model,
|
||||||
|
"messages": call_messages,
|
||||||
|
"temperature": temperature,
|
||||||
|
"max_tokens": max_tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
# 只在tools_param不为None时添加工具参数
|
||||||
|
if tools_param is not None:
|
||||||
|
# 清理工具定义,移除$schema字段(某些API不支持)
|
||||||
|
cleaned_tools = []
|
||||||
|
for tool in tools_param:
|
||||||
|
cleaned_tool = tool.copy()
|
||||||
|
if "function" in cleaned_tool and "parameters" in cleaned_tool["function"]:
|
||||||
|
params = cleaned_tool["function"]["parameters"].copy()
|
||||||
|
# 移除$schema字段
|
||||||
|
params.pop("$schema", None)
|
||||||
|
cleaned_tool["function"]["parameters"] = params
|
||||||
|
cleaned_tools.append(cleaned_tool)
|
||||||
|
|
||||||
|
payload["tools"] = cleaned_tools
|
||||||
|
if tool_choice_param:
|
||||||
|
payload["tool_choice"] = tool_choice_param
|
||||||
|
|
||||||
|
response = await self.openai_http_client.post(url, headers=headers, json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
# 定义测试函数(检测API是否支持Function Calling)
|
||||||
|
async def test_fc():
|
||||||
|
"""测试Function Calling支持"""
|
||||||
|
test_tools = [{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "test_function",
|
||||||
|
"description": "测试函数",
|
||||||
|
"parameters": {"type": "object", "properties": {}}
|
||||||
|
}
|
||||||
|
}]
|
||||||
try:
|
try:
|
||||||
logger.info(f"🔵 开始调用OpenAI API(支持工具调用)")
|
result = await call_api("测试", tools_param=test_tools, tool_choice_param="none")
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Function Calling测试失败: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 使用适配器处理(自动检测、降级、缓存)
|
||||||
|
result = await self.mcp_adapter.call_with_fallback(
|
||||||
|
api_identifier=api_identifier,
|
||||||
|
tools=tools,
|
||||||
|
user_message=prompt,
|
||||||
|
call_function=call_api,
|
||||||
|
test_function=test_fc
|
||||||
|
)
|
||||||
|
|
||||||
|
# 转换结果格式
|
||||||
|
if result.has_tool_calls:
|
||||||
|
return {
|
||||||
|
"tool_calls": result.tool_calls,
|
||||||
|
"content": result.raw_response,
|
||||||
|
"finish_reason": "tool_calls"
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"content": result.raw_response,
|
||||||
|
"finish_reason": "stop"
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ MCP适配器调用失败: {str(e)}")
|
||||||
|
# 降级到原始实现
|
||||||
|
logger.warning("⚠️ 降级到原始OpenAI调用")
|
||||||
|
|
||||||
|
# 原始实现(无适配器或降级)
|
||||||
|
try:
|
||||||
|
logger.info(f"🔵 开始调用OpenAI API(原始模式)")
|
||||||
logger.info(f" - 模型: {model}")
|
logger.info(f" - 模型: {model}")
|
||||||
logger.info(f" - 工具数量: {len(tools) if tools else 0}")
|
logger.info(f" - 工具数量: {len(tools) if tools else 0}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user