From 69e3e46c96e7ec071e1e87f9d6d265664c7ba92c Mon Sep 17 00:00:00 2001 From: xiamuceer Date: Mon, 24 Nov 2025 11:30:27 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96MCP=E5=B7=A5=E5=85=B7?= =?UTF-8?q?=E8=B0=83=E7=94=A8=E4=BD=93=E9=AA=8C=E5=B9=B6=E9=9B=86=E6=88=90?= =?UTF-8?q?=E9=80=9A=E7=94=A8=E9=80=82=E9=85=8D=E5=99=A8=20-=20=E9=9D=99?= =?UTF-8?q?=E9=BB=98=E6=A3=80=E6=9F=A5MCP=E5=B7=A5=E5=85=B7=E5=8F=AF?= =?UTF-8?q?=E7=94=A8=E6=80=A7=EF=BC=8C=E6=94=AF=E6=8C=81=E6=8F=90=E7=A4=BA?= =?UTF-8?q?=E8=AF=8D=E6=B3=A8=E5=85=A5=E8=B0=83=E7=94=A8mcp=20-=20?= =?UTF-8?q?=E9=9B=86=E6=88=90UniversalMCPAdapter=EF=BC=8C=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E8=87=AA=E5=8A=A8API=E8=83=BD=E5=8A=9B=E6=A3=80?= =?UTF-8?q?=E6=B5=8B=E5=92=8C=E6=99=BA=E8=83=BD=E9=99=8D=E7=BA=A7=20-=20?= =?UTF-8?q?=E6=96=B0=E5=A2=9EMCP=E9=80=82=E9=85=8D=E5=99=A8=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=E9=A1=B9=EF=BC=8C=E5=A2=9E=E5=BC=BA=E7=B3=BB=E7=BB=9F?= =?UTF-8?q?=E5=85=BC=E5=AE=B9=E6=80=A7=E5=92=8C=E5=81=A5=E5=A3=AE=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/app/api/wizard_stream.py | 119 ++++--- backend/app/config.py | 5 + backend/app/mcp/adapters/__init__.py | 14 + backend/app/mcp/adapters/base.py | 89 +++++ backend/app/mcp/adapters/function_calling.py | 171 +++++++++ backend/app/mcp/adapters/prompt_injection.py | 274 ++++++++++++++ backend/app/mcp/adapters/universal.py | 353 +++++++++++++++++++ backend/app/services/ai_service.py | 110 +++++- 8 files changed, 1085 insertions(+), 50 deletions(-) create mode 100644 backend/app/mcp/adapters/__init__.py create mode 100644 backend/app/mcp/adapters/base.py create mode 100644 backend/app/mcp/adapters/function_calling.py create mode 100644 backend/app/mcp/adapters/prompt_injection.py create mode 100644 backend/app/mcp/adapters/universal.py diff --git a/backend/app/api/wizard_stream.py b/backend/app/api/wizard_stream.py index d6b0f55..1100cca 100644 --- a/backend/app/api/wizard_stream.py +++ b/backend/app/api/wizard_stream.py @@ -68,11 +68,19 @@ async def world_building_generator( reference_materials = "" if enable_mcp and user_id: try: - yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集参考资料...", 18) + # 先静默检查是否有可用工具 + from app.services.mcp_tool_service import mcp_tool_service + available_tools = await mcp_tool_service.get_user_enabled_tools( + user_id=user_id, + db_session=db + ) - # 直接调用MCP增强的AI,内部会自动检查和加载工具 - # 构建资料收集提示词 - planning_prompt = f"""你正在为小说《{title}》设计世界观。 + # 只有在真正有可用工具时才显示消息和调用 + if available_tools: + yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集参考资料...", 18) + + # 构建资料收集提示词 + planning_prompt = f"""你正在为小说《{title}》设计世界观。 【小说信息】 - 题材:{genre} @@ -88,28 +96,32 @@ async def world_building_generator( 4. 类似作品的设定参考 请根据题材特点,有针对性地查询2-3个关键问题。""" - - # 调用MCP增强的AI(非流式,最多2轮工具调用) - planning_result = await user_ai_service.generate_text_with_mcp( - prompt=planning_prompt, - user_id=user_id, - db_session=db, - enable_mcp=True, - max_tool_rounds=2, - tool_choice="auto", - provider=None, - model=None - ) - - # 提取参考资料 - if planning_result.get("tool_calls_made", 0) > 0: - yield await SSEResponse.send_progress( - f"✅ MCP工具调用成功({planning_result['tool_calls_made']}次)", - 25 + + # 调用MCP增强的AI(非流式,最多2轮工具调用) + planning_result = await user_ai_service.generate_text_with_mcp( + prompt=planning_prompt, + user_id=user_id, + db_session=db, + enable_mcp=True, + max_tool_rounds=2, + tool_choice="auto", + provider=None, + model=None ) - reference_materials = planning_result.get("content", "") + + # 提取参考资料 + if planning_result.get("tool_calls_made", 0) > 0: + yield await SSEResponse.send_progress( + f"✅ MCP工具调用成功({planning_result['tool_calls_made']}次)", + 25 + ) + reference_materials = planning_result.get("content", "") + else: + # 有工具但未使用 + logger.debug("MCP工具可用但AI未选择使用") else: - yield await SSEResponse.send_progress("ℹ️ 未使用MCP工具(无可用工具或不需要)", 25) + # 没有可用工具,静默跳过 + logger.debug(f"用户 {user_id} 未启用MCP工具,跳过MCP增强") except Exception as e: logger.warning(f"MCP工具调用失败(降级处理): {e}") @@ -325,10 +337,19 @@ async def characters_generator( character_reference_materials = "" if enable_mcp and user_id: try: - yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集角色参考资料...", 8) + # 先静默检查是否有可用工具 + from app.services.mcp_tool_service import mcp_tool_service + available_tools = await mcp_tool_service.get_user_enabled_tools( + user_id=user_id, + db_session=db + ) - # 构建角色资料收集提示词 - planning_prompt = f"""你正在为小说《{project.title}》设计角色。 + # 只有在真正有可用工具时才显示消息和调用 + if available_tools: + yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集角色参考资料...", 8) + + # 构建角色资料收集提示词 + planning_prompt = f"""你正在为小说《{project.title}》设计角色。 【小说信息】 - 题材:{genre or project.genre} @@ -345,28 +366,32 @@ async def characters_generator( 4. 相关领域的人物原型 请根据题材特点,有针对性地查询1-2个关键问题。""" - - # 调用MCP增强的AI(非流式,最多2轮工具调用) - planning_result = await user_ai_service.generate_text_with_mcp( - prompt=planning_prompt, - user_id=user_id, - db_session=db, - enable_mcp=True, - max_tool_rounds=2, - tool_choice="auto", - provider=None, - model=None - ) - - # 提取参考资料 - if planning_result.get("tool_calls_made", 0) > 0: - yield await SSEResponse.send_progress( - f"✅ MCP工具调用成功({planning_result['tool_calls_made']}次)", - 12 + + # 调用MCP增强的AI(非流式,最多2轮工具调用) + planning_result = await user_ai_service.generate_text_with_mcp( + prompt=planning_prompt, + user_id=user_id, + db_session=db, + enable_mcp=True, + max_tool_rounds=2, + tool_choice="auto", + provider=None, + model=None ) - character_reference_materials = planning_result.get("content", "") + + # 提取参考资料 + if planning_result.get("tool_calls_made", 0) > 0: + yield await SSEResponse.send_progress( + f"✅ MCP工具调用成功({planning_result['tool_calls_made']}次)", + 12 + ) + character_reference_materials = planning_result.get("content", "") + else: + # 有工具但未使用 + logger.debug("MCP工具可用但AI未选择使用") else: - yield await SSEResponse.send_progress("ℹ️ 未使用MCP工具(无可用工具或不需要)", 12) + # 没有可用工具,静默跳过 + logger.debug(f"用户 {user_id} 未启用MCP工具,跳过MCP增强") except Exception as e: logger.warning(f"MCP工具调用失败(降级处理): {e}") diff --git a/backend/app/config.py b/backend/app/config.py index 456182e..c41817a 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -77,6 +77,11 @@ class Settings(BaseSettings): default_temperature: float = 0.7 default_max_tokens: int = 2000 + # MCP适配器配置 + enable_mcp_adapter: bool = True # 是否启用MCP适配器(自动检测API能力) + mcp_adapter_cache_ttl_hours: int = 24 # API能力检测缓存时长(小时) + mcp_adapter_auto_fallback: bool = True # 是否启用自动降级(FC失败时切换到提示词注入) + # LinuxDO OAuth2 配置 LINUXDO_CLIENT_ID: Optional[str] = None LINUXDO_CLIENT_SECRET: Optional[str] = None diff --git a/backend/app/mcp/adapters/__init__.py b/backend/app/mcp/adapters/__init__.py new file mode 100644 index 0000000..54f5236 --- /dev/null +++ b/backend/app/mcp/adapters/__init__.py @@ -0,0 +1,14 @@ +"""MCP适配器模块 - 支持多种AI API的工具调用方式""" + +from .base import BaseMCPAdapter, AdapterType +from .prompt_injection import PromptInjectionAdapter +from .function_calling import FunctionCallingAdapter +from .universal import UniversalMCPAdapter + +__all__ = [ + "BaseMCPAdapter", + "AdapterType", + "PromptInjectionAdapter", + "FunctionCallingAdapter", + "UniversalMCPAdapter", +] \ No newline at end of file diff --git a/backend/app/mcp/adapters/base.py b/backend/app/mcp/adapters/base.py new file mode 100644 index 0000000..a744987 --- /dev/null +++ b/backend/app/mcp/adapters/base.py @@ -0,0 +1,89 @@ +"""MCP适配器基类""" + +from abc import ABC, abstractmethod +from enum import Enum +from typing import Dict, Any, List, Optional +from dataclasses import dataclass + + +class AdapterType(Enum): + """适配器类型""" + FUNCTION_CALLING = "function_calling" # 标准Function Calling + PROMPT_INJECTION = "prompt_injection" # 提示词注入 + REACT = "react" # ReAct模式 + XML = "xml" # XML标记 + + +@dataclass +class ToolCallResult: + """工具调用结果""" + tool_calls: List[Dict[str, Any]] # 解析出的工具调用 + raw_response: str # 原始AI响应 + has_tool_calls: bool # 是否包含工具调用 + needs_continuation: bool = False # 是否需要继续对话 + + +class BaseMCPAdapter(ABC): + """MCP适配器基类""" + + def __init__(self): + self.adapter_type: AdapterType = AdapterType.PROMPT_INJECTION + + @abstractmethod + def format_tools_for_prompt( + self, + tools: List[Dict[str, Any]], + user_message: str + ) -> str: + """ + 将工具列表格式化为提示词 + + Args: + tools: MCP工具列表 + user_message: 用户消息 + + Returns: + 格式化后的提示词 + """ + pass + + @abstractmethod + def parse_tool_calls(self, ai_response: str) -> ToolCallResult: + """ + 从AI响应中解析工具调用 + + Args: + ai_response: AI的原始响应 + + Returns: + 解析结果 + """ + pass + + @abstractmethod + def build_continuation_prompt( + self, + original_message: str, + ai_response: str, + tool_results: List[Dict[str, Any]] + ) -> str: + """ + 构建包含工具结果的继续对话提示词 + + Args: + original_message: 原始用户消息 + ai_response: AI响应 + tool_results: 工具执行结果 + + Returns: + 继续对话的提示词 + """ + pass + + def supports_native_tools(self) -> bool: + """是否支持原生工具调用(如Function Calling)""" + return False + + def get_adapter_type(self) -> AdapterType: + """获取适配器类型""" + return self.adapter_type \ No newline at end of file diff --git a/backend/app/mcp/adapters/function_calling.py b/backend/app/mcp/adapters/function_calling.py new file mode 100644 index 0000000..d302bb5 --- /dev/null +++ b/backend/app/mcp/adapters/function_calling.py @@ -0,0 +1,171 @@ +"""Function Calling适配器 - 支持原生Function Calling的API""" + +import json +from typing import Dict, Any, List +from app.mcp.adapters.base import BaseMCPAdapter, AdapterType, ToolCallResult +from app.logger import get_logger + +logger = get_logger(__name__) + + +class FunctionCallingAdapter(BaseMCPAdapter): + """Function Calling适配器 - 用于支持原生工具调用的AI API(如OpenAI)""" + + def __init__(self): + super().__init__() + self.adapter_type = AdapterType.FUNCTION_CALLING + + def supports_native_tools(self) -> bool: + """支持原生工具调用""" + return True + + def format_tools_for_prompt( + self, + tools: List[Dict[str, Any]], + user_message: str + ) -> str: + """ + Function Calling模式下,工具通过API参数传递,不需要修改提示词 + + Returns: + 原始用户消息 + """ + return user_message + + def get_tools_for_api(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + 获取适用于API的工具格式 + + Args: + tools: MCP工具列表 + + Returns: + 适用于OpenAI Function Calling的工具格式 + """ + return tools + + def parse_tool_calls(self, ai_response: Any) -> ToolCallResult: + """ + 从AI响应中解析工具调用(Function Calling格式) + + Args: + ai_response: AI响应对象(通常是OpenAI的ChatCompletion对象) + + Returns: + 解析结果 + """ + + try: + # 处理不同类型的响应 + if isinstance(ai_response, dict): + # 字典格式(OpenAI API响应) + message = ai_response.get("choices", [{}])[0].get("message", {}) + tool_calls = message.get("tool_calls", []) + content = message.get("content", "") + + elif hasattr(ai_response, "choices"): + # 对象格式(OpenAI SDK响应) + message = ai_response.choices[0].message + tool_calls = getattr(message, "tool_calls", None) or [] + content = getattr(message, "content", "") or "" + + # 转换为字典格式 + if tool_calls: + tool_calls = [ + { + "id": tc.id, + "type": tc.type, + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments + } + } + for tc in tool_calls + ] + else: + # 字符串格式(降级为文本响应) + return ToolCallResult( + tool_calls=[], + raw_response=str(ai_response), + has_tool_calls=False + ) + + has_tool_calls = len(tool_calls) > 0 + + if has_tool_calls: + logger.info(f"✅ Function Calling模式解析出 {len(tool_calls)} 个工具调用") + for tc in tool_calls: + logger.info(f" - {tc['function']['name']}") + + return ToolCallResult( + tool_calls=tool_calls, + raw_response=content or "", + has_tool_calls=has_tool_calls, + needs_continuation=has_tool_calls + ) + + except Exception as e: + logger.error(f"❌ 解析Function Calling响应失败: {e}", exc_info=True) + return ToolCallResult( + tool_calls=[], + raw_response=str(ai_response), + has_tool_calls=False + ) + + def build_continuation_prompt( + self, + original_message: str, + ai_response: str, + tool_results: List[Dict[str, Any]] + ) -> str: + """ + 构建包含工具结果的继续对话提示词 + + 在Function Calling模式下,这通常不需要,因为工具结果会作为消息历史的一部分 + """ + # Function Calling模式下通常通过消息历史传递工具结果 + # 这里提供一个降级方案 + results_text = "\n\n".join([ + f"工具 {r['name']} 的结果:\n{r['content']}" + for r in tool_results + ]) + + return f"{original_message}\n\n工具执行结果:\n{results_text}\n\n请基于以上工具结果回答用户的问题。" + + def build_messages_with_tool_results( + self, + messages: List[Dict[str, Any]], + tool_calls: List[Dict[str, Any]], + tool_results: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """ + 构建包含工具结果的消息历史(Function Calling标准格式) + + Args: + messages: 原始消息历史 + tool_calls: AI的工具调用 + tool_results: 工具执行结果 + + Returns: + 更新后的消息历史 + """ + + new_messages = messages.copy() + + # 添加助手的工具调用消息 + new_messages.append({ + "role": "assistant", + "content": None, + "tool_calls": tool_calls + }) + + # 添加工具结果消息 + for result in tool_results: + new_messages.append({ + "role": "tool", + "tool_call_id": result.get("tool_call_id", ""), + "name": result.get("name", ""), + "content": result.get("content", "") + }) + + return new_messages \ No newline at end of file diff --git a/backend/app/mcp/adapters/prompt_injection.py b/backend/app/mcp/adapters/prompt_injection.py new file mode 100644 index 0000000..15ea7b8 --- /dev/null +++ b/backend/app/mcp/adapters/prompt_injection.py @@ -0,0 +1,274 @@ +"""提示词注入适配器 - 最通用的MCP工具调用方式""" + +import re +import json +from typing import Dict, Any, List +from app.mcp.adapters.base import BaseMCPAdapter, AdapterType, ToolCallResult +from app.logger import get_logger + +logger = get_logger(__name__) + + +class PromptInjectionAdapter(BaseMCPAdapter): + """提示词注入适配器 - 将工具转换为文本描述,通过提示词引导AI调用""" + + def __init__(self): + super().__init__() + self.adapter_type = AdapterType.PROMPT_INJECTION + + def format_tools_for_prompt( + self, + tools: List[Dict[str, Any]], + user_message: str + ) -> str: + """将工具列表注入到提示词中""" + + if not tools: + return user_message + + # 格式化工具描述 + tool_descriptions = self._format_tools_as_text(tools) + + # 构建增强的提示词 + enhanced_prompt = f"""你现在可以使用以下工具来帮助回答用户的问题。 + +## 可用工具 + +{tool_descriptions} + +## 工具使用说明 + +当你需要使用工具时,请按以下XML格式输出(可以一次调用多个工具): + + + +工具名称 + +{{ + "参数名1": "参数值1", + "参数名2": "参数值2" +}} + + + + +## 重要提示 + +1. 只有在确实需要使用工具时才调用工具 +2. 参数必须是有效的JSON格式 +3. 仔细检查参数是否符合工具的要求 +4. 可以在一个标签内包含多个 +5. 调用工具后,你会收到工具的执行结果,然后需要基于结果继续回答 + +--- + +用户问题:{user_message} + +请分析问题,判断是否需要使用工具。如果需要,先输出工具调用,然后等待结果。如果不需要,直接回答问题。""" + + return enhanced_prompt + + def _format_tools_as_text(self, tools: List[Dict[str, Any]]) -> str: + """将工具格式化为可读的文本描述""" + lines = [] + + for i, tool in enumerate(tools, 1): + func = tool.get("function", {}) + name = func.get("name", "unknown") + description = func.get("description", "无描述") + parameters = func.get("parameters", {}) + + lines.append(f"### {i}. {name}") + lines.append(f"**描述**: {description}") + lines.append("") + + # 格式化参数信息 + if parameters and "properties" in parameters: + lines.append("**参数**:") + properties = parameters.get("properties", {}) + required = parameters.get("required", []) + + for param_name, param_info in properties.items(): + param_type = param_info.get("type", "string") + param_desc = param_info.get("description", "") + is_required = "必填" if param_name in required else "可选" + + lines.append(f" - `{param_name}` ({param_type}, {is_required}): {param_desc}") + lines.append("") + + # 添加示例 + if "example" in func: + lines.append(f"**示例**: {json.dumps(func['example'], ensure_ascii=False)}") + lines.append("") + + return "\n".join(lines) + + def parse_tool_calls(self, ai_response) -> ToolCallResult: + """从AI响应中解析工具调用""" + + tool_calls = [] + + try: + # 处理不同类型的响应 + if isinstance(ai_response, dict): + # 如果是字典,提取content字段 + ai_response = ai_response.get("choices", [{}])[0].get("message", {}).get("content", "") + if not ai_response: + return ToolCallResult( + tool_calls=[], + raw_response="", + has_tool_calls=False + ) + elif not isinstance(ai_response, str): + # 转换为字符串 + ai_response = str(ai_response) + + # 使用正则提取 标签内容 + tool_calls_match = re.search( + r'(.*?)', + ai_response, + re.DOTALL | re.IGNORECASE + ) + + if not tool_calls_match: + # 没有找到工具调用 + return ToolCallResult( + tool_calls=[], + raw_response=ai_response, + has_tool_calls=False + ) + + tool_calls_content = tool_calls_match.group(1) + + # 提取所有 标签 + tool_call_pattern = r'(.*?)' + tool_call_matches = re.findall( + tool_call_pattern, + tool_calls_content, + re.DOTALL | re.IGNORECASE + ) + + for i, tool_call_content in enumerate(tool_call_matches): + # 提取工具名称 + name_match = re.search( + r'(.*?)', + tool_call_content, + re.DOTALL | re.IGNORECASE + ) + + # 提取参数 + args_match = re.search( + r'(.*?)', + tool_call_content, + re.DOTALL | re.IGNORECASE + ) + + if name_match and args_match: + tool_name = name_match.group(1).strip() + arguments_str = args_match.group(1).strip() + + try: + # 解析JSON参数 + arguments = json.loads(arguments_str) + + # 构建标准格式的工具调用 + tool_calls.append({ + "id": f"call_{i}", + "type": "function", + "function": { + "name": tool_name, + "arguments": json.dumps(arguments, ensure_ascii=False) + } + }) + + logger.info(f"✅ 解析工具调用: {tool_name}") + + except json.JSONDecodeError as e: + logger.error(f"❌ 解析工具参数失败: {arguments_str}, 错误: {e}") + continue + + has_tool_calls = len(tool_calls) > 0 + + if has_tool_calls: + logger.info(f"✅ 从响应中解析出 {len(tool_calls)} 个工具调用") + + return ToolCallResult( + tool_calls=tool_calls, + raw_response=ai_response, + has_tool_calls=has_tool_calls, + needs_continuation=has_tool_calls + ) + + except Exception as e: + logger.error(f"❌ 解析工具调用失败: {e}", exc_info=True) + return ToolCallResult( + tool_calls=[], + raw_response=ai_response, + has_tool_calls=False + ) + + def build_continuation_prompt( + self, + original_message: str, + ai_response: str, + tool_results: List[Dict[str, Any]] + ) -> str: + """构建包含工具结果的继续对话提示词""" + + # 格式化工具结果 + results_text = self._format_tool_results(tool_results) + + continuation = f"""你之前尝试使用工具来回答用户的问题。 + +原始问题:{original_message} + +你的工具调用: +{self._extract_tool_calls_text(ai_response)} + +工具执行结果: +{results_text} + +现在,请基于这些工具的执行结果,给出完整、详细的回答。不要重复调用工具,直接使用已有的结果来回答用户的问题。""" + + return continuation + + def _format_tool_results(self, tool_results: List[Dict[str, Any]]) -> str: + """格式化工具结果为可读文本""" + lines = [] + + for i, result in enumerate(tool_results, 1): + tool_name = result.get("name", "unknown") + success = result.get("success", False) + content = result.get("content", "") + + status = "✅ 成功" if success else "❌ 失败" + lines.append(f"{i}. {tool_name} - {status}") + + if success: + # 尝试美化JSON内容 + try: + if isinstance(content, str): + content_obj = json.loads(content) + content = json.dumps(content_obj, ensure_ascii=False, indent=2) + except: + pass + lines.append(f"```\n{content}\n```") + else: + error = result.get("error", "未知错误") + lines.append(f"错误信息: {error}") + + lines.append("") + + return "\n".join(lines) + + def _extract_tool_calls_text(self, ai_response: str) -> str: + """从AI响应中提取工具调用部分的文本""" + match = re.search( + r'(.*?)', + ai_response, + re.DOTALL | re.IGNORECASE + ) + + if match: + return match.group(0) + return "(未找到工具调用)" \ No newline at end of file diff --git a/backend/app/mcp/adapters/universal.py b/backend/app/mcp/adapters/universal.py new file mode 100644 index 0000000..aa2be30 --- /dev/null +++ b/backend/app/mcp/adapters/universal.py @@ -0,0 +1,353 @@ +"""通用MCP适配器 - 自动检测API能力并选择最佳适配器""" + +import time +import asyncio +from typing import Dict, Any, List, Optional +from datetime import datetime, timedelta +from dataclasses import dataclass + +from app.mcp.adapters.base import BaseMCPAdapter, AdapterType, ToolCallResult +from app.mcp.adapters.prompt_injection import PromptInjectionAdapter +from app.mcp.adapters.function_calling import FunctionCallingAdapter +from app.logger import get_logger + +logger = get_logger(__name__) + + +@dataclass +class APICapability: + """API能力检测结果""" + supports_function_calling: bool + tested_at: datetime + test_duration_ms: float + error_message: Optional[str] = None + + +class UniversalMCPAdapter: + """ + 通用MCP适配器管理器 + + 功能: + 1. 自动检测API是否支持Function Calling + 2. 缓存检测结果 + 3. 自动降级策略:FC失败时切换到提示词注入 + 4. 提供统一接口 + """ + + def __init__( + self, + cache_ttl_hours: int = 24, + enable_auto_fallback: bool = True + ): + """ + 初始化通用适配器 + + Args: + cache_ttl_hours: 能力检测缓存时长(小时) + enable_auto_fallback: 是否启用自动降级 + """ + # 适配器实例 + self.adapters = { + AdapterType.FUNCTION_CALLING: FunctionCallingAdapter(), + AdapterType.PROMPT_INJECTION: PromptInjectionAdapter() + } + + # API能力缓存: {api_identifier: APICapability} + self._capability_cache: Dict[str, APICapability] = {} + self._cache_ttl = timedelta(hours=cache_ttl_hours) + self._cache_lock = asyncio.Lock() + + # 配置 + self._enable_auto_fallback = enable_auto_fallback + + logger.info( + f"✅ UniversalMCPAdapter初始化完成 " + f"(缓存TTL={cache_ttl_hours}小时, 自动降级={'开启' if enable_auto_fallback else '关闭'})" + ) + + async def get_adapter( + self, + api_identifier: str, + test_function: Optional[callable] = None + ) -> BaseMCPAdapter: + """ + 获取适合当前API的适配器 + + Args: + api_identifier: API标识符(如"openai_official", "azure_openai"等) + test_function: 可选的测试函数,用于检测API能力 + + Returns: + 最适合的适配器实例 + """ + + # 检查缓存 + capability = await self._get_cached_capability(api_identifier) + + if capability is None and test_function: + # 缓存未命中,执行检测 + capability = await self._detect_capability(api_identifier, test_function) + + # 选择适配器 + if capability and capability.supports_function_calling: + logger.info(f"🎯 使用Function Calling适配器: {api_identifier}") + return self.adapters[AdapterType.FUNCTION_CALLING] + else: + logger.info(f"🎯 使用提示词注入适配器: {api_identifier}") + return self.adapters[AdapterType.PROMPT_INJECTION] + + async def _get_cached_capability( + self, + api_identifier: str + ) -> Optional[APICapability]: + """获取缓存的能力检测结果""" + + async with self._cache_lock: + if api_identifier not in self._capability_cache: + return None + + capability = self._capability_cache[api_identifier] + + # 检查是否过期 + if datetime.now() - capability.tested_at > self._cache_ttl: + logger.info(f"⏰ API能力缓存过期: {api_identifier}") + del self._capability_cache[api_identifier] + return None + + logger.debug(f"🎯 API能力缓存命中: {api_identifier}") + return capability + + async def _detect_capability( + self, + api_identifier: str, + test_function: callable + ) -> APICapability: + """ + 检测API能力 + + Args: + api_identifier: API标识符 + test_function: 测试函数,应该尝试使用Function Calling + + Returns: + 能力检测结果 + """ + + logger.info(f"🔍 开始检测API能力: {api_identifier}") + start_time = time.time() + + try: + # 调用测试函数 + result = await test_function() + + # 判断是否成功 + supports_fc = self._is_function_calling_response(result) + + duration_ms = (time.time() - start_time) * 1000 + + capability = APICapability( + supports_function_calling=supports_fc, + tested_at=datetime.now(), + test_duration_ms=duration_ms + ) + + # 缓存结果 + async with self._cache_lock: + self._capability_cache[api_identifier] = capability + + status = "✅ 支持" if supports_fc else "❌ 不支持" + logger.info( + f"{status} Function Calling: {api_identifier} " + f"(耗时: {duration_ms:.2f}ms)" + ) + + return capability + + except Exception as e: + duration_ms = (time.time() - start_time) * 1000 + + logger.warning( + f"⚠️ API能力检测失败: {api_identifier}, 错误: {e}, " + f"将使用提示词注入模式" + ) + + capability = APICapability( + supports_function_calling=False, + tested_at=datetime.now(), + test_duration_ms=duration_ms, + error_message=str(e) + ) + + # 缓存失败结果(避免重复测试) + async with self._cache_lock: + self._capability_cache[api_identifier] = capability + + return capability + + def _is_function_calling_response(self, response: Any) -> bool: + """ + 判断响应是否是Function Calling格式 + + Args: + response: API响应 + + Returns: + 是否支持Function Calling + """ + + try: + # 检查字典格式 + if isinstance(response, dict): + message = response.get("choices", [{}])[0].get("message", {}) + return "tool_calls" in message or "function_call" in message + + # 检查对象格式(OpenAI SDK) + if hasattr(response, "choices"): + message = response.choices[0].message + return hasattr(message, "tool_calls") or hasattr(message, "function_call") + + return False + + except Exception: + return False + + async def call_with_fallback( + self, + api_identifier: str, + tools: List[Dict[str, Any]], + user_message: str, + call_function: callable, + test_function: Optional[callable] = None + ) -> ToolCallResult: + """ + 带降级策略的工具调用 + + Args: + api_identifier: API标识符 + tools: MCP工具列表 + user_message: 用户消息 + call_function: 实际调用API的函数 + test_function: 可选的测试函数 + + Returns: + 工具调用结果 + """ + + # 获取适配器 + adapter = await self.get_adapter(api_identifier, test_function) + + # 首次尝试 + try: + if adapter.supports_native_tools(): + # Function Calling模式 + logger.info("🚀 尝试使用Function Calling模式") + result = await self._try_function_calling( + tools, user_message, call_function, adapter + ) + else: + # 提示词注入模式 + logger.info("🚀 使用提示词注入模式") + result = await self._try_prompt_injection( + tools, user_message, call_function, adapter + ) + + return result + + except Exception as e: + logger.error(f"❌ 工具调用失败: {e}") + + # 自动降级 + if self._enable_auto_fallback and adapter.supports_native_tools(): + logger.warning("⚠️ Function Calling失败,降级到提示词注入模式") + + # 更新缓存,标记为不支持 + async with self._cache_lock: + self._capability_cache[api_identifier] = APICapability( + supports_function_calling=False, + tested_at=datetime.now(), + test_duration_ms=0, + error_message=str(e) + ) + + # 使用提示词注入重试 + fallback_adapter = self.adapters[AdapterType.PROMPT_INJECTION] + return await self._try_prompt_injection( + tools, user_message, call_function, fallback_adapter + ) + + raise + + async def _try_function_calling( + self, + tools: List[Dict[str, Any]], + user_message: str, + call_function: callable, + adapter: FunctionCallingAdapter + ) -> ToolCallResult: + """尝试Function Calling模式""" + + # Function Calling不需要修改提示词 + response = await call_function( + message=user_message, + tools_param=tools, + tool_choice_param="auto" + ) + + return adapter.parse_tool_calls(response) + + async def _try_prompt_injection( + self, + tools: List[Dict[str, Any]], + user_message: str, + call_function: callable, + adapter: PromptInjectionAdapter + ) -> ToolCallResult: + """尝试提示词注入模式""" + + # 注入工具到提示词 + enhanced_prompt = adapter.format_tools_for_prompt(tools, user_message) + + # 调用API(不传tools参数) + response = await call_function( + message=enhanced_prompt, + tools_param=None, + tool_choice_param=None + ) + + # 从文本响应中解析工具调用 + return adapter.parse_tool_calls(response) + + def clear_cache(self, api_identifier: Optional[str] = None): + """ + 清理能力缓存 + + Args: + api_identifier: 可选,只清理特定API的缓存 + """ + if api_identifier: + if api_identifier in self._capability_cache: + del self._capability_cache[api_identifier] + logger.info(f"🧹 已清理API能力缓存: {api_identifier}") + else: + self._capability_cache.clear() + logger.info("🧹 已清理所有API能力缓存") + + def get_cache_stats(self) -> Dict[str, Any]: + """获取缓存统计信息""" + return { + "total_cached": len(self._capability_cache), + "cache_ttl_hours": self._cache_ttl.total_seconds() / 3600, + "cached_apis": [ + { + "api_identifier": api_id, + "supports_fc": cap.supports_function_calling, + "tested_at": cap.tested_at.isoformat(), + "test_duration_ms": cap.test_duration_ms + } + for api_id, cap in self._capability_cache.items() + ] + } + + +# 全局单例 +universal_mcp_adapter = UniversalMCPAdapter() \ No newline at end of file diff --git a/backend/app/services/ai_service.py b/backend/app/services/ai_service.py index 62cd8b0..49888ce 100644 --- a/backend/app/services/ai_service.py +++ b/backend/app/services/ai_service.py @@ -4,6 +4,7 @@ from openai import AsyncOpenAI from anthropic import AsyncAnthropic from app.config import settings as app_settings from app.logger import get_logger +from app.mcp.adapters import UniversalMCPAdapter, PromptInjectionAdapter import httpx import json import hashlib @@ -118,7 +119,8 @@ class AIService: api_base_url: Optional[str] = None, default_model: Optional[str] = None, default_temperature: Optional[float] = None, - default_max_tokens: Optional[int] = None + default_max_tokens: Optional[int] = None, + enable_mcp_adapter: bool = True ): """ 初始化AI客户端(优化并发性能) @@ -137,6 +139,15 @@ class AIService: self.default_temperature = default_temperature or app_settings.default_temperature self.default_max_tokens = default_max_tokens or app_settings.default_max_tokens + # 初始化MCP适配器 + self.enable_mcp_adapter = enable_mcp_adapter + if enable_mcp_adapter: + self.mcp_adapter = UniversalMCPAdapter() + logger.info("✅ MCP通用适配器已启用") + else: + self.mcp_adapter = None + logger.info("⚠️ MCP适配器已禁用") + # 初始化OpenAI客户端(使用HTTP客户端池) openai_key = api_key if api_provider == "openai" else app_settings.openai_api_key if openai_key: @@ -396,7 +407,7 @@ class AIService: tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[str] = None ) -> Dict[str, Any]: - """使用OpenAI生成文本(支持工具调用)""" + """使用OpenAI生成文本(支持工具调用,集成MCP适配器)""" if not self.openai_http_client: raise ValueError("OpenAI客户端未初始化,请检查API key配置") @@ -405,8 +416,101 @@ class AIService: messages.append({"role": "system", "content": system_prompt}) messages.append({"role": "user", "content": prompt}) + # 如果启用了MCP适配器且有工具,使用适配器处理 + if self.enable_mcp_adapter and self.mcp_adapter and tools: + logger.info(f"🎯 使用MCP适配器处理工具调用") + + # 生成API标识符 + api_identifier = f"openai_{self.openai_base_url or 'default'}" + + # 定义API调用函数 + async def call_api(message: str, tools_param: Optional[List] = None, tool_choice_param: Optional[str] = None): + """实际调用OpenAI API的函数""" + call_messages = messages.copy() + call_messages[-1]["content"] = message + + url = f"{self.openai_base_url}/chat/completions" + headers = { + "Authorization": f"Bearer {self.openai_api_key}", + "Content-Type": "application/json" + } + payload = { + "model": model, + "messages": call_messages, + "temperature": temperature, + "max_tokens": max_tokens + } + + # 只在tools_param不为None时添加工具参数 + if tools_param is not None: + # 清理工具定义,移除$schema字段(某些API不支持) + cleaned_tools = [] + for tool in tools_param: + cleaned_tool = tool.copy() + if "function" in cleaned_tool and "parameters" in cleaned_tool["function"]: + params = cleaned_tool["function"]["parameters"].copy() + # 移除$schema字段 + params.pop("$schema", None) + cleaned_tool["function"]["parameters"] = params + cleaned_tools.append(cleaned_tool) + + payload["tools"] = cleaned_tools + if tool_choice_param: + payload["tool_choice"] = tool_choice_param + + response = await self.openai_http_client.post(url, headers=headers, json=payload) + response.raise_for_status() + return response.json() + + # 定义测试函数(检测API是否支持Function Calling) + async def test_fc(): + """测试Function Calling支持""" + test_tools = [{ + "type": "function", + "function": { + "name": "test_function", + "description": "测试函数", + "parameters": {"type": "object", "properties": {}} + } + }] + try: + result = await call_api("测试", tools_param=test_tools, tool_choice_param="none") + return result + except Exception as e: + logger.debug(f"Function Calling测试失败: {e}") + raise + + try: + # 使用适配器处理(自动检测、降级、缓存) + result = await self.mcp_adapter.call_with_fallback( + api_identifier=api_identifier, + tools=tools, + user_message=prompt, + call_function=call_api, + test_function=test_fc + ) + + # 转换结果格式 + if result.has_tool_calls: + return { + "tool_calls": result.tool_calls, + "content": result.raw_response, + "finish_reason": "tool_calls" + } + else: + return { + "content": result.raw_response, + "finish_reason": "stop" + } + + except Exception as e: + logger.error(f"❌ MCP适配器调用失败: {str(e)}") + # 降级到原始实现 + logger.warning("⚠️ 降级到原始OpenAI调用") + + # 原始实现(无适配器或降级) try: - logger.info(f"🔵 开始调用OpenAI API(支持工具调用)") + logger.info(f"🔵 开始调用OpenAI API(原始模式)") logger.info(f" - 模型: {model}") logger.info(f" - 工具数量: {len(tools) if tools else 0}")