diff --git a/backend/app/api/wizard_stream.py b/backend/app/api/wizard_stream.py
index d6b0f55..1100cca 100644
--- a/backend/app/api/wizard_stream.py
+++ b/backend/app/api/wizard_stream.py
@@ -68,11 +68,19 @@ async def world_building_generator(
reference_materials = ""
if enable_mcp and user_id:
try:
- yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集参考资料...", 18)
+ # 先静默检查是否有可用工具
+ from app.services.mcp_tool_service import mcp_tool_service
+ available_tools = await mcp_tool_service.get_user_enabled_tools(
+ user_id=user_id,
+ db_session=db
+ )
- # 直接调用MCP增强的AI,内部会自动检查和加载工具
- # 构建资料收集提示词
- planning_prompt = f"""你正在为小说《{title}》设计世界观。
+ # 只有在真正有可用工具时才显示消息和调用
+ if available_tools:
+ yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集参考资料...", 18)
+
+ # 构建资料收集提示词
+ planning_prompt = f"""你正在为小说《{title}》设计世界观。
【小说信息】
- 题材:{genre}
@@ -88,28 +96,32 @@ async def world_building_generator(
4. 类似作品的设定参考
请根据题材特点,有针对性地查询2-3个关键问题。"""
-
- # 调用MCP增强的AI(非流式,最多2轮工具调用)
- planning_result = await user_ai_service.generate_text_with_mcp(
- prompt=planning_prompt,
- user_id=user_id,
- db_session=db,
- enable_mcp=True,
- max_tool_rounds=2,
- tool_choice="auto",
- provider=None,
- model=None
- )
-
- # 提取参考资料
- if planning_result.get("tool_calls_made", 0) > 0:
- yield await SSEResponse.send_progress(
- f"✅ MCP工具调用成功({planning_result['tool_calls_made']}次)",
- 25
+
+ # 调用MCP增强的AI(非流式,最多2轮工具调用)
+ planning_result = await user_ai_service.generate_text_with_mcp(
+ prompt=planning_prompt,
+ user_id=user_id,
+ db_session=db,
+ enable_mcp=True,
+ max_tool_rounds=2,
+ tool_choice="auto",
+ provider=None,
+ model=None
)
- reference_materials = planning_result.get("content", "")
+
+ # 提取参考资料
+ if planning_result.get("tool_calls_made", 0) > 0:
+ yield await SSEResponse.send_progress(
+ f"✅ MCP工具调用成功({planning_result['tool_calls_made']}次)",
+ 25
+ )
+ reference_materials = planning_result.get("content", "")
+ else:
+ # 有工具但未使用
+ logger.debug("MCP工具可用但AI未选择使用")
else:
- yield await SSEResponse.send_progress("ℹ️ 未使用MCP工具(无可用工具或不需要)", 25)
+ # 没有可用工具,静默跳过
+ logger.debug(f"用户 {user_id} 未启用MCP工具,跳过MCP增强")
except Exception as e:
logger.warning(f"MCP工具调用失败(降级处理): {e}")
@@ -325,10 +337,19 @@ async def characters_generator(
character_reference_materials = ""
if enable_mcp and user_id:
try:
- yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集角色参考资料...", 8)
+ # 先静默检查是否有可用工具
+ from app.services.mcp_tool_service import mcp_tool_service
+ available_tools = await mcp_tool_service.get_user_enabled_tools(
+ user_id=user_id,
+ db_session=db
+ )
- # 构建角色资料收集提示词
- planning_prompt = f"""你正在为小说《{project.title}》设计角色。
+ # 只有在真正有可用工具时才显示消息和调用
+ if available_tools:
+ yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集角色参考资料...", 8)
+
+ # 构建角色资料收集提示词
+ planning_prompt = f"""你正在为小说《{project.title}》设计角色。
【小说信息】
- 题材:{genre or project.genre}
@@ -345,28 +366,32 @@ async def characters_generator(
4. 相关领域的人物原型
请根据题材特点,有针对性地查询1-2个关键问题。"""
-
- # 调用MCP增强的AI(非流式,最多2轮工具调用)
- planning_result = await user_ai_service.generate_text_with_mcp(
- prompt=planning_prompt,
- user_id=user_id,
- db_session=db,
- enable_mcp=True,
- max_tool_rounds=2,
- tool_choice="auto",
- provider=None,
- model=None
- )
-
- # 提取参考资料
- if planning_result.get("tool_calls_made", 0) > 0:
- yield await SSEResponse.send_progress(
- f"✅ MCP工具调用成功({planning_result['tool_calls_made']}次)",
- 12
+
+ # 调用MCP增强的AI(非流式,最多2轮工具调用)
+ planning_result = await user_ai_service.generate_text_with_mcp(
+ prompt=planning_prompt,
+ user_id=user_id,
+ db_session=db,
+ enable_mcp=True,
+ max_tool_rounds=2,
+ tool_choice="auto",
+ provider=None,
+ model=None
)
- character_reference_materials = planning_result.get("content", "")
+
+ # 提取参考资料
+ if planning_result.get("tool_calls_made", 0) > 0:
+ yield await SSEResponse.send_progress(
+ f"✅ MCP工具调用成功({planning_result['tool_calls_made']}次)",
+ 12
+ )
+ character_reference_materials = planning_result.get("content", "")
+ else:
+ # 有工具但未使用
+ logger.debug("MCP工具可用但AI未选择使用")
else:
- yield await SSEResponse.send_progress("ℹ️ 未使用MCP工具(无可用工具或不需要)", 12)
+ # 没有可用工具,静默跳过
+ logger.debug(f"用户 {user_id} 未启用MCP工具,跳过MCP增强")
except Exception as e:
logger.warning(f"MCP工具调用失败(降级处理): {e}")
diff --git a/backend/app/config.py b/backend/app/config.py
index 456182e..c41817a 100644
--- a/backend/app/config.py
+++ b/backend/app/config.py
@@ -77,6 +77,11 @@ class Settings(BaseSettings):
default_temperature: float = 0.7
default_max_tokens: int = 2000
+ # MCP适配器配置
+ enable_mcp_adapter: bool = True # 是否启用MCP适配器(自动检测API能力)
+ mcp_adapter_cache_ttl_hours: int = 24 # API能力检测缓存时长(小时)
+ mcp_adapter_auto_fallback: bool = True # 是否启用自动降级(FC失败时切换到提示词注入)
+
# LinuxDO OAuth2 配置
LINUXDO_CLIENT_ID: Optional[str] = None
LINUXDO_CLIENT_SECRET: Optional[str] = None
diff --git a/backend/app/mcp/adapters/__init__.py b/backend/app/mcp/adapters/__init__.py
new file mode 100644
index 0000000..54f5236
--- /dev/null
+++ b/backend/app/mcp/adapters/__init__.py
@@ -0,0 +1,14 @@
+"""MCP适配器模块 - 支持多种AI API的工具调用方式"""
+
+from .base import BaseMCPAdapter, AdapterType
+from .prompt_injection import PromptInjectionAdapter
+from .function_calling import FunctionCallingAdapter
+from .universal import UniversalMCPAdapter
+
+__all__ = [
+ "BaseMCPAdapter",
+ "AdapterType",
+ "PromptInjectionAdapter",
+ "FunctionCallingAdapter",
+ "UniversalMCPAdapter",
+]
\ No newline at end of file
diff --git a/backend/app/mcp/adapters/base.py b/backend/app/mcp/adapters/base.py
new file mode 100644
index 0000000..a744987
--- /dev/null
+++ b/backend/app/mcp/adapters/base.py
@@ -0,0 +1,89 @@
+"""MCP适配器基类"""
+
+from abc import ABC, abstractmethod
+from enum import Enum
+from typing import Dict, Any, List, Optional
+from dataclasses import dataclass
+
+
+class AdapterType(Enum):
+ """适配器类型"""
+ FUNCTION_CALLING = "function_calling" # 标准Function Calling
+ PROMPT_INJECTION = "prompt_injection" # 提示词注入
+ REACT = "react" # ReAct模式
+ XML = "xml" # XML标记
+
+
+@dataclass
+class ToolCallResult:
+ """工具调用结果"""
+ tool_calls: List[Dict[str, Any]] # 解析出的工具调用
+ raw_response: str # 原始AI响应
+ has_tool_calls: bool # 是否包含工具调用
+ needs_continuation: bool = False # 是否需要继续对话
+
+
+class BaseMCPAdapter(ABC):
+ """MCP适配器基类"""
+
+ def __init__(self):
+ self.adapter_type: AdapterType = AdapterType.PROMPT_INJECTION
+
+ @abstractmethod
+ def format_tools_for_prompt(
+ self,
+ tools: List[Dict[str, Any]],
+ user_message: str
+ ) -> str:
+ """
+ 将工具列表格式化为提示词
+
+ Args:
+ tools: MCP工具列表
+ user_message: 用户消息
+
+ Returns:
+ 格式化后的提示词
+ """
+ pass
+
+ @abstractmethod
+ def parse_tool_calls(self, ai_response: str) -> ToolCallResult:
+ """
+ 从AI响应中解析工具调用
+
+ Args:
+ ai_response: AI的原始响应
+
+ Returns:
+ 解析结果
+ """
+ pass
+
+ @abstractmethod
+ def build_continuation_prompt(
+ self,
+ original_message: str,
+ ai_response: str,
+ tool_results: List[Dict[str, Any]]
+ ) -> str:
+ """
+ 构建包含工具结果的继续对话提示词
+
+ Args:
+ original_message: 原始用户消息
+ ai_response: AI响应
+ tool_results: 工具执行结果
+
+ Returns:
+ 继续对话的提示词
+ """
+ pass
+
+ def supports_native_tools(self) -> bool:
+ """是否支持原生工具调用(如Function Calling)"""
+ return False
+
+ def get_adapter_type(self) -> AdapterType:
+ """获取适配器类型"""
+ return self.adapter_type
\ No newline at end of file
diff --git a/backend/app/mcp/adapters/function_calling.py b/backend/app/mcp/adapters/function_calling.py
new file mode 100644
index 0000000..d302bb5
--- /dev/null
+++ b/backend/app/mcp/adapters/function_calling.py
@@ -0,0 +1,171 @@
+"""Function Calling适配器 - 支持原生Function Calling的API"""
+
+import json
+from typing import Dict, Any, List
+from app.mcp.adapters.base import BaseMCPAdapter, AdapterType, ToolCallResult
+from app.logger import get_logger
+
+logger = get_logger(__name__)
+
+
+class FunctionCallingAdapter(BaseMCPAdapter):
+ """Function Calling适配器 - 用于支持原生工具调用的AI API(如OpenAI)"""
+
+ def __init__(self):
+ super().__init__()
+ self.adapter_type = AdapterType.FUNCTION_CALLING
+
+ def supports_native_tools(self) -> bool:
+ """支持原生工具调用"""
+ return True
+
+ def format_tools_for_prompt(
+ self,
+ tools: List[Dict[str, Any]],
+ user_message: str
+ ) -> str:
+ """
+ Function Calling模式下,工具通过API参数传递,不需要修改提示词
+
+ Returns:
+ 原始用户消息
+ """
+ return user_message
+
+ def get_tools_for_api(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """
+ 获取适用于API的工具格式
+
+ Args:
+ tools: MCP工具列表
+
+ Returns:
+ 适用于OpenAI Function Calling的工具格式
+ """
+ return tools
+
+ def parse_tool_calls(self, ai_response: Any) -> ToolCallResult:
+ """
+ 从AI响应中解析工具调用(Function Calling格式)
+
+ Args:
+ ai_response: AI响应对象(通常是OpenAI的ChatCompletion对象)
+
+ Returns:
+ 解析结果
+ """
+
+ try:
+ # 处理不同类型的响应
+ if isinstance(ai_response, dict):
+ # 字典格式(OpenAI API响应)
+ message = ai_response.get("choices", [{}])[0].get("message", {})
+ tool_calls = message.get("tool_calls", [])
+ content = message.get("content", "")
+
+ elif hasattr(ai_response, "choices"):
+ # 对象格式(OpenAI SDK响应)
+ message = ai_response.choices[0].message
+ tool_calls = getattr(message, "tool_calls", None) or []
+ content = getattr(message, "content", "") or ""
+
+ # 转换为字典格式
+ if tool_calls:
+ tool_calls = [
+ {
+ "id": tc.id,
+ "type": tc.type,
+ "function": {
+ "name": tc.function.name,
+ "arguments": tc.function.arguments
+ }
+ }
+ for tc in tool_calls
+ ]
+ else:
+ # 字符串格式(降级为文本响应)
+ return ToolCallResult(
+ tool_calls=[],
+ raw_response=str(ai_response),
+ has_tool_calls=False
+ )
+
+ has_tool_calls = len(tool_calls) > 0
+
+ if has_tool_calls:
+ logger.info(f"✅ Function Calling模式解析出 {len(tool_calls)} 个工具调用")
+ for tc in tool_calls:
+ logger.info(f" - {tc['function']['name']}")
+
+ return ToolCallResult(
+ tool_calls=tool_calls,
+ raw_response=content or "",
+ has_tool_calls=has_tool_calls,
+ needs_continuation=has_tool_calls
+ )
+
+ except Exception as e:
+ logger.error(f"❌ 解析Function Calling响应失败: {e}", exc_info=True)
+ return ToolCallResult(
+ tool_calls=[],
+ raw_response=str(ai_response),
+ has_tool_calls=False
+ )
+
+ def build_continuation_prompt(
+ self,
+ original_message: str,
+ ai_response: str,
+ tool_results: List[Dict[str, Any]]
+ ) -> str:
+ """
+ 构建包含工具结果的继续对话提示词
+
+ 在Function Calling模式下,这通常不需要,因为工具结果会作为消息历史的一部分
+ """
+ # Function Calling模式下通常通过消息历史传递工具结果
+ # 这里提供一个降级方案
+ results_text = "\n\n".join([
+ f"工具 {r['name']} 的结果:\n{r['content']}"
+ for r in tool_results
+ ])
+
+ return f"{original_message}\n\n工具执行结果:\n{results_text}\n\n请基于以上工具结果回答用户的问题。"
+
+ def build_messages_with_tool_results(
+ self,
+ messages: List[Dict[str, Any]],
+ tool_calls: List[Dict[str, Any]],
+ tool_results: List[Dict[str, Any]]
+ ) -> List[Dict[str, Any]]:
+ """
+ 构建包含工具结果的消息历史(Function Calling标准格式)
+
+ Args:
+ messages: 原始消息历史
+ tool_calls: AI的工具调用
+ tool_results: 工具执行结果
+
+ Returns:
+ 更新后的消息历史
+ """
+
+ new_messages = messages.copy()
+
+ # 添加助手的工具调用消息
+ new_messages.append({
+ "role": "assistant",
+ "content": None,
+ "tool_calls": tool_calls
+ })
+
+ # 添加工具结果消息
+ for result in tool_results:
+ new_messages.append({
+ "role": "tool",
+ "tool_call_id": result.get("tool_call_id", ""),
+ "name": result.get("name", ""),
+ "content": result.get("content", "")
+ })
+
+ return new_messages
\ No newline at end of file
diff --git a/backend/app/mcp/adapters/prompt_injection.py b/backend/app/mcp/adapters/prompt_injection.py
new file mode 100644
index 0000000..15ea7b8
--- /dev/null
+++ b/backend/app/mcp/adapters/prompt_injection.py
@@ -0,0 +1,274 @@
+"""提示词注入适配器 - 最通用的MCP工具调用方式"""
+
+import re
+import json
+from typing import Dict, Any, List
+from app.mcp.adapters.base import BaseMCPAdapter, AdapterType, ToolCallResult
+from app.logger import get_logger
+
+logger = get_logger(__name__)
+
+
+class PromptInjectionAdapter(BaseMCPAdapter):
+ """提示词注入适配器 - 将工具转换为文本描述,通过提示词引导AI调用"""
+
+ def __init__(self):
+ super().__init__()
+ self.adapter_type = AdapterType.PROMPT_INJECTION
+
+ def format_tools_for_prompt(
+ self,
+ tools: List[Dict[str, Any]],
+ user_message: str
+ ) -> str:
+ """将工具列表注入到提示词中"""
+
+ if not tools:
+ return user_message
+
+ # 格式化工具描述
+ tool_descriptions = self._format_tools_as_text(tools)
+
+ # 构建增强的提示词
+ enhanced_prompt = f"""你现在可以使用以下工具来帮助回答用户的问题。
+
+## 可用工具
+
+{tool_descriptions}
+
+## 工具使用说明
+
+当你需要使用工具时,请按以下XML格式输出(可以一次调用多个工具):
+
+
+
+工具名称
+
+{{
+ "参数名1": "参数值1",
+ "参数名2": "参数值2"
+}}
+
+
+
+
+## 重要提示
+
+1. 只有在确实需要使用工具时才调用工具
+2. 参数必须是有效的JSON格式
+3. 仔细检查参数是否符合工具的要求
+4. 可以在一个标签内包含多个
+5. 调用工具后,你会收到工具的执行结果,然后需要基于结果继续回答
+
+---
+
+用户问题:{user_message}
+
+请分析问题,判断是否需要使用工具。如果需要,先输出工具调用,然后等待结果。如果不需要,直接回答问题。"""
+
+ return enhanced_prompt
+
+ def _format_tools_as_text(self, tools: List[Dict[str, Any]]) -> str:
+ """将工具格式化为可读的文本描述"""
+ lines = []
+
+ for i, tool in enumerate(tools, 1):
+ func = tool.get("function", {})
+ name = func.get("name", "unknown")
+ description = func.get("description", "无描述")
+ parameters = func.get("parameters", {})
+
+ lines.append(f"### {i}. {name}")
+ lines.append(f"**描述**: {description}")
+ lines.append("")
+
+ # 格式化参数信息
+ if parameters and "properties" in parameters:
+ lines.append("**参数**:")
+ properties = parameters.get("properties", {})
+ required = parameters.get("required", [])
+
+ for param_name, param_info in properties.items():
+ param_type = param_info.get("type", "string")
+ param_desc = param_info.get("description", "")
+ is_required = "必填" if param_name in required else "可选"
+
+ lines.append(f" - `{param_name}` ({param_type}, {is_required}): {param_desc}")
+ lines.append("")
+
+ # 添加示例
+ if "example" in func:
+ lines.append(f"**示例**: {json.dumps(func['example'], ensure_ascii=False)}")
+ lines.append("")
+
+ return "\n".join(lines)
+
+ def parse_tool_calls(self, ai_response) -> ToolCallResult:
+ """从AI响应中解析工具调用"""
+
+ tool_calls = []
+
+ try:
+ # 处理不同类型的响应
+ if isinstance(ai_response, dict):
+ # 如果是字典,提取content字段
+ ai_response = ai_response.get("choices", [{}])[0].get("message", {}).get("content", "")
+ if not ai_response:
+ return ToolCallResult(
+ tool_calls=[],
+ raw_response="",
+ has_tool_calls=False
+ )
+ elif not isinstance(ai_response, str):
+ # 转换为字符串
+ ai_response = str(ai_response)
+
+ # 使用正则提取 标签内容
+ tool_calls_match = re.search(
+ r'(.*?)',
+ ai_response,
+ re.DOTALL | re.IGNORECASE
+ )
+
+ if not tool_calls_match:
+ # 没有找到工具调用
+ return ToolCallResult(
+ tool_calls=[],
+ raw_response=ai_response,
+ has_tool_calls=False
+ )
+
+ tool_calls_content = tool_calls_match.group(1)
+
+ # 提取所有 标签
+ tool_call_pattern = r'(.*?)'
+ tool_call_matches = re.findall(
+ tool_call_pattern,
+ tool_calls_content,
+ re.DOTALL | re.IGNORECASE
+ )
+
+ for i, tool_call_content in enumerate(tool_call_matches):
+ # 提取工具名称
+ name_match = re.search(
+ r'(.*?)',
+ tool_call_content,
+ re.DOTALL | re.IGNORECASE
+ )
+
+ # 提取参数
+ args_match = re.search(
+ r'(.*?)',
+ tool_call_content,
+ re.DOTALL | re.IGNORECASE
+ )
+
+ if name_match and args_match:
+ tool_name = name_match.group(1).strip()
+ arguments_str = args_match.group(1).strip()
+
+ try:
+ # 解析JSON参数
+ arguments = json.loads(arguments_str)
+
+ # 构建标准格式的工具调用
+ tool_calls.append({
+ "id": f"call_{i}",
+ "type": "function",
+ "function": {
+ "name": tool_name,
+ "arguments": json.dumps(arguments, ensure_ascii=False)
+ }
+ })
+
+ logger.info(f"✅ 解析工具调用: {tool_name}")
+
+ except json.JSONDecodeError as e:
+ logger.error(f"❌ 解析工具参数失败: {arguments_str}, 错误: {e}")
+ continue
+
+ has_tool_calls = len(tool_calls) > 0
+
+ if has_tool_calls:
+ logger.info(f"✅ 从响应中解析出 {len(tool_calls)} 个工具调用")
+
+ return ToolCallResult(
+ tool_calls=tool_calls,
+ raw_response=ai_response,
+ has_tool_calls=has_tool_calls,
+ needs_continuation=has_tool_calls
+ )
+
+ except Exception as e:
+ logger.error(f"❌ 解析工具调用失败: {e}", exc_info=True)
+ return ToolCallResult(
+ tool_calls=[],
+ raw_response=ai_response,
+ has_tool_calls=False
+ )
+
+ def build_continuation_prompt(
+ self,
+ original_message: str,
+ ai_response: str,
+ tool_results: List[Dict[str, Any]]
+ ) -> str:
+ """构建包含工具结果的继续对话提示词"""
+
+ # 格式化工具结果
+ results_text = self._format_tool_results(tool_results)
+
+ continuation = f"""你之前尝试使用工具来回答用户的问题。
+
+原始问题:{original_message}
+
+你的工具调用:
+{self._extract_tool_calls_text(ai_response)}
+
+工具执行结果:
+{results_text}
+
+现在,请基于这些工具的执行结果,给出完整、详细的回答。不要重复调用工具,直接使用已有的结果来回答用户的问题。"""
+
+ return continuation
+
+ def _format_tool_results(self, tool_results: List[Dict[str, Any]]) -> str:
+ """格式化工具结果为可读文本"""
+ lines = []
+
+ for i, result in enumerate(tool_results, 1):
+ tool_name = result.get("name", "unknown")
+ success = result.get("success", False)
+ content = result.get("content", "")
+
+ status = "✅ 成功" if success else "❌ 失败"
+ lines.append(f"{i}. {tool_name} - {status}")
+
+ if success:
+ # 尝试美化JSON内容
+ try:
+ if isinstance(content, str):
+ content_obj = json.loads(content)
+ content = json.dumps(content_obj, ensure_ascii=False, indent=2)
+ except:
+ pass
+ lines.append(f"```\n{content}\n```")
+ else:
+ error = result.get("error", "未知错误")
+ lines.append(f"错误信息: {error}")
+
+ lines.append("")
+
+ return "\n".join(lines)
+
+ def _extract_tool_calls_text(self, ai_response: str) -> str:
+ """从AI响应中提取工具调用部分的文本"""
+ match = re.search(
+ r'(.*?)',
+ ai_response,
+ re.DOTALL | re.IGNORECASE
+ )
+
+ if match:
+ return match.group(0)
+ return "(未找到工具调用)"
\ No newline at end of file
diff --git a/backend/app/mcp/adapters/universal.py b/backend/app/mcp/adapters/universal.py
new file mode 100644
index 0000000..aa2be30
--- /dev/null
+++ b/backend/app/mcp/adapters/universal.py
@@ -0,0 +1,353 @@
+"""通用MCP适配器 - 自动检测API能力并选择最佳适配器"""
+
+import time
+import asyncio
+from typing import Dict, Any, List, Optional
+from datetime import datetime, timedelta
+from dataclasses import dataclass
+
+from app.mcp.adapters.base import BaseMCPAdapter, AdapterType, ToolCallResult
+from app.mcp.adapters.prompt_injection import PromptInjectionAdapter
+from app.mcp.adapters.function_calling import FunctionCallingAdapter
+from app.logger import get_logger
+
+logger = get_logger(__name__)
+
+
+@dataclass
+class APICapability:
+ """API能力检测结果"""
+ supports_function_calling: bool
+ tested_at: datetime
+ test_duration_ms: float
+ error_message: Optional[str] = None
+
+
+class UniversalMCPAdapter:
+ """
+ 通用MCP适配器管理器
+
+ 功能:
+ 1. 自动检测API是否支持Function Calling
+ 2. 缓存检测结果
+ 3. 自动降级策略:FC失败时切换到提示词注入
+ 4. 提供统一接口
+ """
+
+ def __init__(
+ self,
+ cache_ttl_hours: int = 24,
+ enable_auto_fallback: bool = True
+ ):
+ """
+ 初始化通用适配器
+
+ Args:
+ cache_ttl_hours: 能力检测缓存时长(小时)
+ enable_auto_fallback: 是否启用自动降级
+ """
+ # 适配器实例
+ self.adapters = {
+ AdapterType.FUNCTION_CALLING: FunctionCallingAdapter(),
+ AdapterType.PROMPT_INJECTION: PromptInjectionAdapter()
+ }
+
+ # API能力缓存: {api_identifier: APICapability}
+ self._capability_cache: Dict[str, APICapability] = {}
+ self._cache_ttl = timedelta(hours=cache_ttl_hours)
+ self._cache_lock = asyncio.Lock()
+
+ # 配置
+ self._enable_auto_fallback = enable_auto_fallback
+
+ logger.info(
+ f"✅ UniversalMCPAdapter初始化完成 "
+ f"(缓存TTL={cache_ttl_hours}小时, 自动降级={'开启' if enable_auto_fallback else '关闭'})"
+ )
+
+ async def get_adapter(
+ self,
+ api_identifier: str,
+ test_function: Optional[callable] = None
+ ) -> BaseMCPAdapter:
+ """
+ 获取适合当前API的适配器
+
+ Args:
+ api_identifier: API标识符(如"openai_official", "azure_openai"等)
+ test_function: 可选的测试函数,用于检测API能力
+
+ Returns:
+ 最适合的适配器实例
+ """
+
+ # 检查缓存
+ capability = await self._get_cached_capability(api_identifier)
+
+ if capability is None and test_function:
+ # 缓存未命中,执行检测
+ capability = await self._detect_capability(api_identifier, test_function)
+
+ # 选择适配器
+ if capability and capability.supports_function_calling:
+ logger.info(f"🎯 使用Function Calling适配器: {api_identifier}")
+ return self.adapters[AdapterType.FUNCTION_CALLING]
+ else:
+ logger.info(f"🎯 使用提示词注入适配器: {api_identifier}")
+ return self.adapters[AdapterType.PROMPT_INJECTION]
+
+ async def _get_cached_capability(
+ self,
+ api_identifier: str
+ ) -> Optional[APICapability]:
+ """获取缓存的能力检测结果"""
+
+ async with self._cache_lock:
+ if api_identifier not in self._capability_cache:
+ return None
+
+ capability = self._capability_cache[api_identifier]
+
+ # 检查是否过期
+ if datetime.now() - capability.tested_at > self._cache_ttl:
+ logger.info(f"⏰ API能力缓存过期: {api_identifier}")
+ del self._capability_cache[api_identifier]
+ return None
+
+ logger.debug(f"🎯 API能力缓存命中: {api_identifier}")
+ return capability
+
+ async def _detect_capability(
+ self,
+ api_identifier: str,
+ test_function: callable
+ ) -> APICapability:
+ """
+ 检测API能力
+
+ Args:
+ api_identifier: API标识符
+ test_function: 测试函数,应该尝试使用Function Calling
+
+ Returns:
+ 能力检测结果
+ """
+
+ logger.info(f"🔍 开始检测API能力: {api_identifier}")
+ start_time = time.time()
+
+ try:
+ # 调用测试函数
+ result = await test_function()
+
+ # 判断是否成功
+ supports_fc = self._is_function_calling_response(result)
+
+ duration_ms = (time.time() - start_time) * 1000
+
+ capability = APICapability(
+ supports_function_calling=supports_fc,
+ tested_at=datetime.now(),
+ test_duration_ms=duration_ms
+ )
+
+ # 缓存结果
+ async with self._cache_lock:
+ self._capability_cache[api_identifier] = capability
+
+ status = "✅ 支持" if supports_fc else "❌ 不支持"
+ logger.info(
+ f"{status} Function Calling: {api_identifier} "
+ f"(耗时: {duration_ms:.2f}ms)"
+ )
+
+ return capability
+
+ except Exception as e:
+ duration_ms = (time.time() - start_time) * 1000
+
+ logger.warning(
+ f"⚠️ API能力检测失败: {api_identifier}, 错误: {e}, "
+ f"将使用提示词注入模式"
+ )
+
+ capability = APICapability(
+ supports_function_calling=False,
+ tested_at=datetime.now(),
+ test_duration_ms=duration_ms,
+ error_message=str(e)
+ )
+
+ # 缓存失败结果(避免重复测试)
+ async with self._cache_lock:
+ self._capability_cache[api_identifier] = capability
+
+ return capability
+
+ def _is_function_calling_response(self, response: Any) -> bool:
+ """
+ 判断响应是否是Function Calling格式
+
+ Args:
+ response: API响应
+
+ Returns:
+ 是否支持Function Calling
+ """
+
+ try:
+ # 检查字典格式
+ if isinstance(response, dict):
+ message = response.get("choices", [{}])[0].get("message", {})
+ return "tool_calls" in message or "function_call" in message
+
+ # 检查对象格式(OpenAI SDK)
+ if hasattr(response, "choices"):
+ message = response.choices[0].message
+ return hasattr(message, "tool_calls") or hasattr(message, "function_call")
+
+ return False
+
+ except Exception:
+ return False
+
+ async def call_with_fallback(
+ self,
+ api_identifier: str,
+ tools: List[Dict[str, Any]],
+ user_message: str,
+ call_function: callable,
+ test_function: Optional[callable] = None
+ ) -> ToolCallResult:
+ """
+ 带降级策略的工具调用
+
+ Args:
+ api_identifier: API标识符
+ tools: MCP工具列表
+ user_message: 用户消息
+ call_function: 实际调用API的函数
+ test_function: 可选的测试函数
+
+ Returns:
+ 工具调用结果
+ """
+
+ # 获取适配器
+ adapter = await self.get_adapter(api_identifier, test_function)
+
+ # 首次尝试
+ try:
+ if adapter.supports_native_tools():
+ # Function Calling模式
+ logger.info("🚀 尝试使用Function Calling模式")
+ result = await self._try_function_calling(
+ tools, user_message, call_function, adapter
+ )
+ else:
+ # 提示词注入模式
+ logger.info("🚀 使用提示词注入模式")
+ result = await self._try_prompt_injection(
+ tools, user_message, call_function, adapter
+ )
+
+ return result
+
+ except Exception as e:
+ logger.error(f"❌ 工具调用失败: {e}")
+
+ # 自动降级
+ if self._enable_auto_fallback and adapter.supports_native_tools():
+ logger.warning("⚠️ Function Calling失败,降级到提示词注入模式")
+
+ # 更新缓存,标记为不支持
+ async with self._cache_lock:
+ self._capability_cache[api_identifier] = APICapability(
+ supports_function_calling=False,
+ tested_at=datetime.now(),
+ test_duration_ms=0,
+ error_message=str(e)
+ )
+
+ # 使用提示词注入重试
+ fallback_adapter = self.adapters[AdapterType.PROMPT_INJECTION]
+ return await self._try_prompt_injection(
+ tools, user_message, call_function, fallback_adapter
+ )
+
+ raise
+
+ async def _try_function_calling(
+ self,
+ tools: List[Dict[str, Any]],
+ user_message: str,
+ call_function: callable,
+ adapter: FunctionCallingAdapter
+ ) -> ToolCallResult:
+ """尝试Function Calling模式"""
+
+ # Function Calling不需要修改提示词
+ response = await call_function(
+ message=user_message,
+ tools_param=tools,
+ tool_choice_param="auto"
+ )
+
+ return adapter.parse_tool_calls(response)
+
+ async def _try_prompt_injection(
+ self,
+ tools: List[Dict[str, Any]],
+ user_message: str,
+ call_function: callable,
+ adapter: PromptInjectionAdapter
+ ) -> ToolCallResult:
+ """尝试提示词注入模式"""
+
+ # 注入工具到提示词
+ enhanced_prompt = adapter.format_tools_for_prompt(tools, user_message)
+
+ # 调用API(不传tools参数)
+ response = await call_function(
+ message=enhanced_prompt,
+ tools_param=None,
+ tool_choice_param=None
+ )
+
+ # 从文本响应中解析工具调用
+ return adapter.parse_tool_calls(response)
+
+ def clear_cache(self, api_identifier: Optional[str] = None):
+ """
+ 清理能力缓存
+
+ Args:
+ api_identifier: 可选,只清理特定API的缓存
+ """
+ if api_identifier:
+ if api_identifier in self._capability_cache:
+ del self._capability_cache[api_identifier]
+ logger.info(f"🧹 已清理API能力缓存: {api_identifier}")
+ else:
+ self._capability_cache.clear()
+ logger.info("🧹 已清理所有API能力缓存")
+
+ def get_cache_stats(self) -> Dict[str, Any]:
+ """获取缓存统计信息"""
+ return {
+ "total_cached": len(self._capability_cache),
+ "cache_ttl_hours": self._cache_ttl.total_seconds() / 3600,
+ "cached_apis": [
+ {
+ "api_identifier": api_id,
+ "supports_fc": cap.supports_function_calling,
+ "tested_at": cap.tested_at.isoformat(),
+ "test_duration_ms": cap.test_duration_ms
+ }
+ for api_id, cap in self._capability_cache.items()
+ ]
+ }
+
+
+# 全局单例
+universal_mcp_adapter = UniversalMCPAdapter()
\ No newline at end of file
diff --git a/backend/app/services/ai_service.py b/backend/app/services/ai_service.py
index 62cd8b0..49888ce 100644
--- a/backend/app/services/ai_service.py
+++ b/backend/app/services/ai_service.py
@@ -4,6 +4,7 @@ from openai import AsyncOpenAI
from anthropic import AsyncAnthropic
from app.config import settings as app_settings
from app.logger import get_logger
+from app.mcp.adapters import UniversalMCPAdapter, PromptInjectionAdapter
import httpx
import json
import hashlib
@@ -118,7 +119,8 @@ class AIService:
api_base_url: Optional[str] = None,
default_model: Optional[str] = None,
default_temperature: Optional[float] = None,
- default_max_tokens: Optional[int] = None
+ default_max_tokens: Optional[int] = None,
+ enable_mcp_adapter: bool = True
):
"""
初始化AI客户端(优化并发性能)
@@ -137,6 +139,15 @@ class AIService:
self.default_temperature = default_temperature or app_settings.default_temperature
self.default_max_tokens = default_max_tokens or app_settings.default_max_tokens
+ # 初始化MCP适配器
+ self.enable_mcp_adapter = enable_mcp_adapter
+ if enable_mcp_adapter:
+ self.mcp_adapter = UniversalMCPAdapter()
+ logger.info("✅ MCP通用适配器已启用")
+ else:
+ self.mcp_adapter = None
+ logger.info("⚠️ MCP适配器已禁用")
+
# 初始化OpenAI客户端(使用HTTP客户端池)
openai_key = api_key if api_provider == "openai" else app_settings.openai_api_key
if openai_key:
@@ -396,7 +407,7 @@ class AIService:
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None
) -> Dict[str, Any]:
- """使用OpenAI生成文本(支持工具调用)"""
+ """使用OpenAI生成文本(支持工具调用,集成MCP适配器)"""
if not self.openai_http_client:
raise ValueError("OpenAI客户端未初始化,请检查API key配置")
@@ -405,8 +416,101 @@ class AIService:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
+ # 如果启用了MCP适配器且有工具,使用适配器处理
+ if self.enable_mcp_adapter and self.mcp_adapter and tools:
+ logger.info(f"🎯 使用MCP适配器处理工具调用")
+
+ # 生成API标识符
+ api_identifier = f"openai_{self.openai_base_url or 'default'}"
+
+ # 定义API调用函数
+ async def call_api(message: str, tools_param: Optional[List] = None, tool_choice_param: Optional[str] = None):
+ """实际调用OpenAI API的函数"""
+ call_messages = messages.copy()
+ call_messages[-1]["content"] = message
+
+ url = f"{self.openai_base_url}/chat/completions"
+ headers = {
+ "Authorization": f"Bearer {self.openai_api_key}",
+ "Content-Type": "application/json"
+ }
+ payload = {
+ "model": model,
+ "messages": call_messages,
+ "temperature": temperature,
+ "max_tokens": max_tokens
+ }
+
+ # 只在tools_param不为None时添加工具参数
+ if tools_param is not None:
+ # 清理工具定义,移除$schema字段(某些API不支持)
+ cleaned_tools = []
+ for tool in tools_param:
+ cleaned_tool = tool.copy()
+ if "function" in cleaned_tool and "parameters" in cleaned_tool["function"]:
+ params = cleaned_tool["function"]["parameters"].copy()
+ # 移除$schema字段
+ params.pop("$schema", None)
+ cleaned_tool["function"]["parameters"] = params
+ cleaned_tools.append(cleaned_tool)
+
+ payload["tools"] = cleaned_tools
+ if tool_choice_param:
+ payload["tool_choice"] = tool_choice_param
+
+ response = await self.openai_http_client.post(url, headers=headers, json=payload)
+ response.raise_for_status()
+ return response.json()
+
+ # 定义测试函数(检测API是否支持Function Calling)
+ async def test_fc():
+ """测试Function Calling支持"""
+ test_tools = [{
+ "type": "function",
+ "function": {
+ "name": "test_function",
+ "description": "测试函数",
+ "parameters": {"type": "object", "properties": {}}
+ }
+ }]
+ try:
+ result = await call_api("测试", tools_param=test_tools, tool_choice_param="none")
+ return result
+ except Exception as e:
+ logger.debug(f"Function Calling测试失败: {e}")
+ raise
+
+ try:
+ # 使用适配器处理(自动检测、降级、缓存)
+ result = await self.mcp_adapter.call_with_fallback(
+ api_identifier=api_identifier,
+ tools=tools,
+ user_message=prompt,
+ call_function=call_api,
+ test_function=test_fc
+ )
+
+ # 转换结果格式
+ if result.has_tool_calls:
+ return {
+ "tool_calls": result.tool_calls,
+ "content": result.raw_response,
+ "finish_reason": "tool_calls"
+ }
+ else:
+ return {
+ "content": result.raw_response,
+ "finish_reason": "stop"
+ }
+
+ except Exception as e:
+ logger.error(f"❌ MCP适配器调用失败: {str(e)}")
+ # 降级到原始实现
+ logger.warning("⚠️ 降级到原始OpenAI调用")
+
+ # 原始实现(无适配器或降级)
try:
- logger.info(f"🔵 开始调用OpenAI API(支持工具调用)")
+ logger.info(f"🔵 开始调用OpenAI API(原始模式)")
logger.info(f" - 模型: {model}")
logger.info(f" - 工具数量: {len(tools) if tools else 0}")