feat: 重构MCP功能和AI服务提供者架构

This commit is contained in:
xiamuceer-j
2026-01-09 17:13:19 +08:00
parent f3c224261d
commit 77c5489ff8
49 changed files with 4763 additions and 4307 deletions
@@ -71,7 +71,18 @@ class AnthropicClient:
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
) -> AsyncGenerator[str, None]:
tools: Optional[list] = None,
tool_choice: Optional[str] = None,
) -> AsyncGenerator[Dict[str, Any], None]:
"""
流式生成,支持工具调用
Yields:
Dict with keys:
- content: str - 文本内容块
- tool_calls: list - 工具调用列表(如果有)
- done: bool - 是否结束
"""
kwargs = {
"model": model,
"max_tokens": max_tokens,
@@ -80,12 +91,42 @@ class AnthropicClient:
}
if system_prompt:
kwargs["system"] = system_prompt
if tools:
kwargs["tools"] = tools
if tool_choice == "required":
kwargs["tool_choice"] = {"type": "any"}
elif tool_choice == "auto":
kwargs["tool_choice"] = {"type": "auto"}
try:
async with self.client.messages.stream(**kwargs) as stream:
try:
async for text in stream.text_stream:
yield text
tool_calls = []
async for chunk in stream:
# 处理不同类型的块
if chunk.type == "text_delta":
yield {"content": chunk.text}
elif chunk.type == "tool_use_delta":
# 工具调用增量
if not tool_calls or tool_calls[-1].get("id") != chunk.id:
tool_calls.append({
"id": chunk.id,
"type": "function",
"function": {
"name": chunk.name,
"arguments": ""
}
})
# 追加参数
if tool_calls[-1]["function"]["arguments"] is None:
tool_calls[-1]["function"]["arguments"] = ""
tool_calls[-1]["function"]["arguments"] += chunk.input_gets_new_text or ""
elif chunk.type == "message_delta":
if chunk.stop_reason:
# 流结束
if tool_calls:
yield {"tool_calls": tool_calls}
yield {"done": True, "finish_reason": chunk.stop_reason}
except GeneratorExit:
# 生成器被关闭,这是正常的清理过程
logger.debug("Anthropic 流式响应生成器被关闭(GeneratorExit)")
@@ -111,7 +111,18 @@ class GeminiClient:
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
) -> AsyncGenerator[str, None]:
tools: Optional[list] = None,
tool_choice: Optional[str] = None,
) -> AsyncGenerator[Dict[str, Any], None]:
"""
流式生成,支持工具调用
Yields:
Dict with keys:
- content: str - 文本内容块
- tool_calls: list - 工具调用列表(如果有)
- done: bool - 是否结束
"""
url = f"{self.base_url}/models/{model}:streamGenerateContent?key={self.api_key}&alt=sse"
contents = []
@@ -125,6 +136,8 @@ class GeminiClient:
}
if system_prompt:
payload["systemInstruction"] = {"parts": [{"text": system_prompt}]}
if tools:
payload["tools"] = self._convert_tools_to_gemini(tools)
try:
async with self.client.stream("POST", url, json=payload) as response:
@@ -139,9 +152,26 @@ class GeminiClient:
if candidates and len(candidates) > 0:
parts = candidates[0].get("content", {}).get("parts", [])
if parts and len(parts) > 0:
text = parts[0].get("text", "")
text = ""
function_calls = []
for part in parts:
if "text" in part:
text += part["text"]
elif "functionCall" in part:
fc = part["functionCall"]
function_calls.append({
"id": f"call_{fc['name']}",
"type": "function",
"function": {
"name": fc["name"],
"arguments": fc.get("args", {})
}
})
if text:
yield text
yield {"content": text}
if function_calls:
yield {"tool_calls": function_calls}
except json.JSONDecodeError:
continue
except GeneratorExit:
@@ -86,8 +86,21 @@ class OpenAIClient(BaseAIClient):
model: str,
temperature: float,
max_tokens: int,
) -> AsyncGenerator[str, None]:
payload = self._build_payload(messages, model, temperature, max_tokens, stream=True)
tools: Optional[list] = None,
tool_choice: Optional[str] = None,
) -> AsyncGenerator[Dict[str, Any], None]:
"""
流式生成,支持工具调用
Yields:
Dict with keys:
- content: str - 文本内容块
- tool_calls: list - 工具调用列表(如果有)
- done: bool - 是否结束
"""
payload = self._build_payload(messages, model, temperature, max_tokens, tools, tool_choice, stream=True)
tool_calls_buffer = {} # 收集工具调用块
try:
async with await self._request_with_retry("POST", "/chat/completions", payload, stream=True) as response:
@@ -97,14 +110,38 @@ class OpenAIClient(BaseAIClient):
if line.startswith("data: "):
data_str = line[6:]
if data_str.strip() == "[DONE]":
# 流结束,检查是否有工具调用需要处理
if tool_calls_buffer:
yield {"tool_calls": list(tool_calls_buffer.values()), "done": True}
yield {"done": True}
break
try:
data = json.loads(data_str)
choices = data.get("choices", [])
if choices and len(choices) > 0:
content = choices[0].get("delta", {}).get("content", "")
delta = choices[0].get("delta", {})
content = delta.get("content", "")
# 检查工具调用
tc_list = delta.get("tool_calls")
if tc_list:
for tc in tc_list:
index = tc.get("index", 0)
if index not in tool_calls_buffer:
tool_calls_buffer[index] = tc
else:
existing = tool_calls_buffer[index]
# 合并 function.arguments
if "function" in tc and "function" in existing:
if tc["function"].get("arguments"):
existing["function"]["arguments"] = (
existing["function"].get("arguments", "") +
tc["function"]["arguments"]
)
if content:
yield content
yield {"content": content}
except json.JSONDecodeError:
continue
except GeneratorExit:
@@ -1,9 +1,12 @@
"""Anthropic Provider"""
from typing import Any, AsyncGenerator, Dict, List, Optional
from app.logger import get_logger
from app.services.ai_clients.anthropic_client import AnthropicClient
from .base_provider import BaseAIProvider
logger = get_logger(__name__)
class AnthropicProvider(BaseAIProvider):
"""Anthropic 提供商"""
@@ -39,7 +42,62 @@ class AnthropicProvider(BaseAIProvider):
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: Optional[str] = None,
user_id: Optional[str] = None,
) -> AsyncGenerator[str, None]:
# 如果有工具,使用真正的流式工具调用
if tools:
logger.debug(f"🔧 AnthropicProvider: 有 {len(tools)} 个工具,使用流式处理")
messages = [{"role": "user", "content": prompt}]
actual_tool_choice = tool_choice if tool_choice else "auto"
tool_calls_buffer = []
async for chunk in self.client.chat_completion_stream(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_prompt,
tools=tools,
tool_choice=actual_tool_choice,
):
# 检查是否有工具调用
if chunk.get("tool_calls"):
tool_calls_buffer.extend(chunk["tool_calls"])
logger.debug(f"🔧 收到工具调用: {len(chunk['tool_calls'])}")
# 检查是否结束
if chunk.get("done"):
if tool_calls_buffer:
logger.info(f"🔧 流式结束,处理 {len(tool_calls_buffer)} 个工具调用")
from app.mcp import mcp_client
actual_user_id = user_id or ""
tool_results = await mcp_client.batch_call_tools(
user_id=actual_user_id,
tool_calls=tool_calls_buffer
)
# 将工具结果注入到上下文中
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
# 构建最终提示词,要求AI基于工具结果回答
final_prompt = f"{prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"
final_messages = [{"role": "user", "content": final_prompt}]
# 递归调用生成最终结果
async for final_chunk in self._generate_with_tools(
final_messages, model, temperature, max_tokens, system_prompt, tools, user_id
):
yield final_chunk
break
# 输出文本内容
if chunk.get("content"):
yield chunk["content"]
return
# 无工具时普通流式生成
messages = [{"role": "user", "content": prompt}]
async for chunk in self.client.chat_completion_stream(
messages=messages,
@@ -48,4 +106,56 @@ class AnthropicProvider(BaseAIProvider):
max_tokens=max_tokens,
system_prompt=system_prompt,
):
yield chunk
# 确保只 yield 字符串内容,避免 yield 字典导致类型错误
if isinstance(chunk, dict):
if chunk.get("content"):
yield chunk["content"]
else:
yield chunk
async def _generate_with_tools(
    self,
    messages: list,
    model: str,
    temperature: float,
    max_tokens: int,
    system_prompt: Optional[str] = None,
    tools: Optional[list] = None,
    user_id: Optional[str] = None,
    max_rounds: int = 3,
) -> AsyncGenerator[str, None]:
    """Stream a completion, executing any requested tools and recursing.

    Text chunks from the client are yielded as they arrive. Tool calls are
    buffered until the stream finishes; they are then executed through
    ``mcp_client`` and the model is asked again with the tool output
    appended as a new user message.

    Args:
        messages: Conversation so far. The list is not modified.
        model: Model name forwarded to the client.
        temperature: Sampling temperature forwarded to the client.
        max_tokens: Token limit forwarded to the client.
        system_prompt: Optional system prompt forwarded to the client.
        tools: Tool definitions offered to the model; falsy disables tools.
        user_id: User on whose behalf tools run ("" when not provided).
        max_rounds: Remaining tool-execution rounds. The follow-up request
            of the last permitted round is sent without tools so that the
            recursion always terminates.

    Yields:
        Non-empty text fragments of the model's answer.
    """
    pending_calls = []
    async for chunk in self.client.chat_completion_stream(
        messages=messages,
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        system_prompt=system_prompt,
        tools=tools,
        # A tool_choice without tools is rejected by the upstream API.
        tool_choice="auto" if tools else None,
    ):
        # The no-tools branch of generate_stream tolerates plain-string
        # chunks; accept them here as well instead of failing on .get().
        if isinstance(chunk, str):
            if chunk:
                yield chunk
            continue
        if chunk.get("tool_calls"):
            pending_calls.extend(chunk["tool_calls"])
            logger.debug(f"🔧 _generate_with_tools 收到工具调用: {len(chunk['tool_calls'])}")
        if chunk.get("done"):
            break
        if chunk.get("content"):
            yield chunk["content"]

    # Tool calls are handled after the loop so they are not lost when the
    # stream simply ends without an explicit "done" chunk.
    if not pending_calls:
        return
    if max_rounds <= 0:
        logger.warning(f"⚠️ 已达到最大工具调用轮数,忽略 {len(pending_calls)} 个工具调用")
        return

    from app.mcp import mcp_client

    tool_results = await mcp_client.batch_call_tools(
        user_id=user_id or "",
        tool_calls=pending_calls,
    )
    tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
    # Build a new list rather than appending to the caller's messages.
    follow_up = [
        *messages,
        {"role": "user", "content": f"{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"},
    ]
    # On the last permitted round the model must answer without tools.
    next_tools = tools if max_rounds > 1 else None
    async for final_chunk in self._generate_with_tools(
        follow_up, model, temperature, max_tokens, system_prompt,
        next_tools, user_id, max_rounds - 1,
    ):
        yield final_chunk
@@ -28,6 +28,9 @@ class BaseAIProvider(ABC):
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: Optional[str] = None,
user_id: Optional[str] = None,
) -> AsyncGenerator[str, None]:
"""流式生成"""
pass
@@ -1,8 +1,12 @@
"""Gemini Provider"""
from typing import Any, AsyncGenerator, Dict, List, Optional
from app.logger import get_logger
from app.services.ai_clients.gemini_client import GeminiClient
from .base_provider import BaseAIProvider
logger = get_logger(__name__)
class GeminiProvider(BaseAIProvider):
def __init__(self, client: GeminiClient):
@@ -36,7 +40,62 @@ class GeminiProvider(BaseAIProvider):
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: Optional[str] = None,
user_id: Optional[str] = None,
) -> AsyncGenerator[str, None]:
# 如果有工具,使用真正的流式工具调用
if tools:
logger.debug(f"🔧 GeminiProvider: 有 {len(tools)} 个工具,使用流式处理")
messages = [{"role": "user", "content": prompt}]
actual_tool_choice = tool_choice if tool_choice else "auto"
tool_calls_buffer = []
async for chunk in self.client.chat_completion_stream(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_prompt,
tools=tools,
tool_choice=actual_tool_choice,
):
# 检查是否有工具调用
if chunk.get("tool_calls"):
tool_calls_buffer.extend(chunk["tool_calls"])
logger.debug(f"🔧 收到工具调用: {len(chunk['tool_calls'])}")
# 检查是否结束
if chunk.get("done"):
if tool_calls_buffer:
logger.info(f"🔧 流式结束,处理 {len(tool_calls_buffer)} 个工具调用")
from app.mcp import mcp_client
actual_user_id = user_id or ""
tool_results = await mcp_client.batch_call_tools(
user_id=actual_user_id,
tool_calls=tool_calls_buffer
)
# 将工具结果注入到上下文中
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
# 构建最终提示词,要求AI基于工具结果回答
final_prompt = f"{prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"
final_messages = [{"role": "user", "content": final_prompt}]
# 递归调用生成最终结果
async for final_chunk in self._generate_with_tools(
final_messages, model, temperature, max_tokens, system_prompt, tools, user_id
):
yield final_chunk
break
# 输出文本内容
if chunk.get("content"):
yield chunk["content"]
return
# 无工具时普通流式生成
messages = [{"role": "user", "content": prompt}]
async for chunk in self.client.chat_completion_stream(
messages=messages,
@@ -45,4 +104,56 @@ class GeminiProvider(BaseAIProvider):
max_tokens=max_tokens,
system_prompt=system_prompt,
):
yield chunk
# 确保只 yield 字符串内容,避免 yield 字典导致类型错误
if isinstance(chunk, dict):
if chunk.get("content"):
yield chunk["content"]
else:
yield chunk
async def _generate_with_tools(
    self,
    messages: list,
    model: str,
    temperature: float,
    max_tokens: int,
    system_prompt: Optional[str] = None,
    tools: Optional[list] = None,
    user_id: Optional[str] = None,
    max_rounds: int = 3,
) -> AsyncGenerator[str, None]:
    """Stream a Gemini completion, running requested tools and recursing.

    Text chunks are yielded as they arrive. Tool calls are collected while
    streaming and executed once the stream is over, whether or not the
    client emitted an explicit ``done`` chunk. The tool output is then
    appended as a new user message and the model is queried again.

    Args:
        messages: Conversation so far. The list is not modified.
        model: Model name forwarded to the client.
        temperature: Sampling temperature forwarded to the client.
        max_tokens: Token limit forwarded to the client.
        system_prompt: Optional system prompt forwarded to the client.
        tools: Tool definitions offered to the model; falsy disables tools.
        user_id: User on whose behalf tools run ("" when not provided).
        max_rounds: Remaining tool-execution rounds. The follow-up request
            of the last permitted round is sent without tools so that the
            recursion always terminates.

    Yields:
        Non-empty text fragments of the model's answer.
    """
    pending_calls = []
    async for chunk in self.client.chat_completion_stream(
        messages=messages,
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        system_prompt=system_prompt,
        tools=tools,
        tool_choice="auto" if tools else None,
    ):
        # The no-tools branch of generate_stream tolerates plain-string
        # chunks; accept them here as well instead of failing on .get().
        if isinstance(chunk, str):
            if chunk:
                yield chunk
            continue
        if chunk.get("tool_calls"):
            pending_calls.extend(chunk["tool_calls"])
            logger.debug(f"🔧 _generate_with_tools 收到工具调用: {len(chunk['tool_calls'])}")
        if chunk.get("done"):
            break
        if chunk.get("content"):
            yield chunk["content"]

    # Handled after the loop: a stream that ends without a "done" chunk
    # must not silently discard the tool calls it produced.
    if not pending_calls:
        return
    if max_rounds <= 0:
        logger.warning(f"⚠️ 已达到最大工具调用轮数,忽略 {len(pending_calls)} 个工具调用")
        return

    from app.mcp import mcp_client

    tool_results = await mcp_client.batch_call_tools(
        user_id=user_id or "",
        tool_calls=pending_calls,
    )
    tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
    # Build a new list rather than appending to the caller's messages.
    follow_up = [
        *messages,
        {"role": "user", "content": f"{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"},
    ]
    # On the last permitted round the model must answer without tools.
    next_tools = tools if max_rounds > 1 else None
    async for final_chunk in self._generate_with_tools(
        follow_up, model, temperature, max_tokens, system_prompt,
        next_tools, user_id, max_rounds - 1,
    ):
        yield final_chunk
@@ -1,9 +1,12 @@
"""OpenAI Provider"""
from typing import Any, AsyncGenerator, Dict, List, Optional
from app.logger import get_logger
from app.services.ai_clients.openai_client import OpenAIClient
from .base_provider import BaseAIProvider
logger = get_logger(__name__)
class OpenAIProvider(BaseAIProvider):
"""OpenAI 提供商"""
@@ -42,16 +45,117 @@ class OpenAIProvider(BaseAIProvider):
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: Optional[str] = None,
user_id: Optional[str] = None,
) -> AsyncGenerator[str, None]:
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
# 如果有工具,使用真正的流式工具调用
if tools:
logger.debug(f"🔧 OpenAIProvider: 有 {len(tools)} 个工具,使用流式处理")
actual_tool_choice = tool_choice if tool_choice else "auto"
tool_calls_buffer = []
async for chunk in self.client.chat_completion_stream(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
tools=tools,
tool_choice=actual_tool_choice,
):
# 检查是否有工具调用
if chunk.get("tool_calls"):
tool_calls_buffer.extend(chunk["tool_calls"])
logger.debug(f"🔧 收到工具调用: {len(chunk['tool_calls'])}")
# 检查是否结束
if chunk.get("done"):
if tool_calls_buffer:
logger.info(f"🔧 流式结束,处理 {len(tool_calls_buffer)} 个工具调用")
from app.mcp import mcp_client
actual_user_id = user_id or ""
tool_results = await mcp_client.batch_call_tools(
user_id=actual_user_id,
tool_calls=tool_calls_buffer
)
# 将工具结果注入到上下文中
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
# 构建最终提示词,要求AI基于工具结果回答
final_prompt = f"{prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"
final_messages = messages.copy()
final_messages.append({"role": "user", "content": final_prompt})
# 递归调用生成最终结果
async for final_chunk in self._generate_with_tools(
final_messages, model, temperature, max_tokens, tools, user_id
):
yield final_chunk
break
# 输出文本内容
if chunk.get("content"):
yield chunk["content"]
return
# 无工具时普通流式生成
async for chunk in self.client.chat_completion_stream(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
):
yield chunk
# 确保只 yield 字符串内容,避免 yield 字典导致类型错误
if isinstance(chunk, dict):
if chunk.get("content"):
yield chunk["content"]
else:
yield chunk
async def _generate_with_tools(
    self,
    messages: list,
    model: str,
    temperature: float,
    max_tokens: int,
    tools: list,
    user_id: Optional[str] = None,
    max_rounds: int = 3,
) -> AsyncGenerator[str, None]:
    """Stream an OpenAI completion, running requested tools and recursing.

    The model decides freely whether to call tools (``tool_choice="auto"``).
    Text chunks are yielded as they arrive. Tool calls are collected until
    the stream is over, executed through ``mcp_client``, and the model is
    queried again with the tool output appended as a new user message.

    Args:
        messages: Conversation so far. The list is not modified.
        model: Model name forwarded to the client.
        temperature: Sampling temperature forwarded to the client.
        max_tokens: Token limit forwarded to the client.
        tools: Tool definitions offered to the model; falsy disables tools.
        user_id: User on whose behalf tools run ("" when not provided).
        max_rounds: Remaining tool-execution rounds. The follow-up request
            of the last permitted round is sent without tools so that the
            recursion always terminates.

    Yields:
        Non-empty text fragments of the model's answer.
    """
    pending_calls = []
    async for chunk in self.client.chat_completion_stream(
        messages=messages,
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        tools=tools,
        # A tool_choice without tools is rejected by the upstream API.
        tool_choice="auto" if tools else None,
    ):
        # The no-tools branch of generate_stream tolerates plain-string
        # chunks; accept them here as well instead of failing on .get().
        if isinstance(chunk, str):
            if chunk:
                yield chunk
            continue
        if chunk.get("tool_calls"):
            pending_calls.extend(chunk["tool_calls"])
        if chunk.get("done"):
            break
        if chunk.get("content"):
            yield chunk["content"]

    if not pending_calls:
        return
    if max_rounds <= 0:
        logger.warning(f"⚠️ 已达到最大工具调用轮数,忽略 {len(pending_calls)} 个工具调用")
        return

    from app.mcp import mcp_client

    tool_results = await mcp_client.batch_call_tools(
        user_id=user_id or "",
        tool_calls=pending_calls,
    )
    tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
    # Ask again for the final answer; build a new list rather than
    # appending to the caller's messages.
    follow_up = [
        *messages,
        {"role": "user", "content": f"{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"},
    ]
    # On the last permitted round the model must answer without tools.
    next_tools = tools if max_rounds > 1 else None
    async for final_chunk in self._generate_with_tools(
        follow_up, model, temperature, max_tokens, next_tools, user_id, max_rounds - 1,
    ):
        yield final_chunk
+391 -112
View File
@@ -1,4 +1,10 @@
"""AI服务封装 - 统一的AI接口"""
"""AI服务封装 - 统一的AI接口
重构后支持自动MCP工具加载:
- 所有AI方法在请求前自动检查用户MCP配置
- 如果有启用的MCP插件且有可用工具,自动发送tools
- 通过 auto_mcp 参数控制是否启用自动工具加载
"""
from typing import Optional, AsyncGenerator, List, Dict, Any, Union
from app.config import settings as app_settings
@@ -13,7 +19,6 @@ from app.services.ai_providers.anthropic_provider import AnthropicProvider
from app.services.ai_providers.gemini_provider import GeminiProvider
from app.services.ai_providers.base_provider import BaseAIProvider
from app.services.json_helper import clean_json_response, parse_json
from app.mcp.adapters.universal import universal_mcp_adapter
# 导出清理函数
cleanup_http_clients = cleanup_all_clients
@@ -22,7 +27,41 @@ logger = get_logger(__name__)
class AIService:
"""AI服务统一接口"""
"""
AI服务统一接口
MCP工具支持:
- 在创建服务时传入 user_id 和 db_session
- 根据用户MCP插件的enabled状态自动决定是否启用MCP
- 如果有任意一个MCP插件启用,则加载并使用工具
- 如果所有插件都关闭,则不使用任何MCP工具
- 通过 auto_mcp=False 可临时禁用自动工具加载
- 通过 mcp_max_rounds 控制工具调用轮数
- 通过 clear_mcp_cache() 可清理MCP工具缓存
MCP启用逻辑(backend/app/api/settings.py 中的 get_user_ai_service):
- 查询用户的所有MCP插件
- 如果有启用的插件 (enabled=True),则 enable_mcp=True
- 如果所有插件都关闭或没有插件,则 enable_mcp=False
使用示例:
# 创建支持MCP的AI服务(根据插件状态自动决定是否启用)
ai_service = create_user_ai_service_with_mcp(
api_provider="openai",
api_key="...",
user_id="user123",
db_session=db
)
# 自动加载MCP工具(如果有启用的插件)
result = await ai_service.generate_text(prompt="...")
# 临时禁用MCP工具
result = await ai_service.generate_text(prompt="...", auto_mcp=False)
# 自定义轮数
result = await ai_service.generate_text(prompt="...", mcp_max_rounds=3)
"""
def __init__(
self,
@@ -33,8 +72,11 @@ class AIService:
default_temperature: Optional[float] = None,
default_max_tokens: Optional[int] = None,
default_system_prompt: Optional[str] = None,
enable_mcp_adapter: bool = True,
config: Optional[AIClientConfig] = None,
# MCP支持参数
user_id: Optional[str] = None,
db_session: Optional[Any] = None,
enable_mcp: bool = True,
):
self.api_provider = api_provider or app_settings.default_ai_provider
self.default_model = default_model or app_settings.default_model
@@ -43,7 +85,12 @@ class AIService:
self.default_system_prompt = default_system_prompt
self.config = config or default_config
self.mcp_adapter = universal_mcp_adapter if enable_mcp_adapter else None
# MCP配置
self.user_id = user_id
self.db_session = db_session
self._enable_mcp = enable_mcp
self._cached_tools: Optional[List[Dict]] = None
self._tools_loaded = False
self._openai_provider: Optional[OpenAIProvider] = None
self._anthropic_provider: Optional[AnthropicProvider] = None
@@ -68,6 +115,36 @@ class AIService:
client = GeminiClient(api_key, api_base_url, self.config)
self._gemini_provider = GeminiProvider(client)
@property
def enable_mcp(self) -> bool:
    """Return whether MCP tools are currently enabled for this service."""
    return self._enable_mcp

@enable_mcp.setter
def enable_mcp(self, value: bool):
    """Store the MCP flag, dropping cached tools on an enabled -> disabled switch."""
    was_enabled = self._enable_mcp is True
    if was_enabled and value is False:
        # Cached tools must not survive a transition to the disabled state.
        self.clear_mcp_cache()
    self._enable_mcp = value
def clear_mcp_cache(self):
    """
    Drop the cached MCP tool list and reset the loaded flag.

    After this call ``_cached_tools`` is None and ``_tools_loaded`` is
    False, so later code that consults these attributes sees an empty,
    not-yet-loaded state.
    """
    cached = self._cached_tools
    if cached is None:
        logger.debug(f"🔧 MCP工具缓存已经是空,无需清理")
    else:
        logger.info(f"🔧 清理MCP工具缓存,移除 {len(cached)} 个工具")
        self._cached_tools = None
    # Reset the loaded marker regardless of whether anything was cached.
    self._tools_loaded = False
    logger.debug(f"🔧 MCP工具状态已重置: enable_mcp={self._enable_mcp}, _tools_loaded=False")
def _get_provider(self, provider: Optional[str] = None) -> BaseAIProvider:
"""获取对应的 Provider"""
p = provider or self.api_provider
@@ -79,6 +156,166 @@ class AIService:
return self._gemini_provider
raise ValueError(f"Provider {p} 未初始化")
async def _prepare_mcp_tools(self, auto_mcp: bool = True, force_refresh: bool = False) -> Optional[List[Dict]]:
"""
Resolve the MCP tool list to send with the next AI request.

Checks the service's MCP switches and, when allowed, loads the user's
tools through ``mcp_tools_loader.get_user_tools``. The result is stored
in ``self._cached_tools`` and ``self._tools_loaded`` is set so that
later calls can reuse it.

Args:
auto_mcp: Per-call switch supplied by the caller; False skips loading.
force_refresh: When True, bypass the instance-level cache and forward
the flag to the loader.

Returns:
- None: no tools available (MCP disabled, auto_mcp off, missing
user_id/db_session, or the loader raised).
- Otherwise whatever the loader returned (presumably an OpenAI-format
tool list -- confirm against mcp_tools_loader).
"""
# Precondition checks
if not self._enable_mcp:
logger.debug(f"🔧 MCP工具未启用 (_enable_mcp=False)")
# Drop any cached tools as well so they cannot be used
self._cached_tools = None
self._tools_loaded = False
return None
if not auto_mcp:
logger.debug(f"🔧 auto_mcp=False,跳过MCP工具加载")
# Drop any cached tools as well so they cannot be used
# NOTE(review): this also discards the cache for a one-off auto_mcp=False call -- confirm intended
self._cached_tools = None
self._tools_loaded = False
return None
if not self.user_id:
logger.debug(f"🔧 MCP工具加载跳过: user_id未设置")
return None
if not self.db_session:
logger.debug(f"🔧 MCP工具加载跳过: db_session未设置")
return None
# Use the cache (only reached when enable_mcp is True)
if self._tools_loaded and not force_refresh:
if self._cached_tools:
logger.debug(f"🔧 使用缓存的MCP工具 ({len(self._cached_tools)}个)")
return self._cached_tools
try:
# Imported lazily inside the method rather than at module import time
from app.services.mcp_tools_loader import mcp_tools_loader
self._cached_tools = await mcp_tools_loader.get_user_tools(
user_id=self.user_id,
db_session=self.db_session,
use_cache=True,
force_refresh=force_refresh
)
self._tools_loaded = True
if self._cached_tools:
logger.info(f"🔧 已加载 {len(self._cached_tools)} 个MCP工具")
else:
logger.debug(f"📭 用户 {self.user_id} 没有可用的MCP工具")
return self._cached_tools
except Exception as e:
# Best effort: a loader failure is logged and reported as "no tools".
logger.warning(f"⚠️ 加载MCP工具失败: {e}")
# Marked as loaded even on failure, so the cache branch above is taken next time
self._tools_loaded = True
self._cached_tools = None
return None
async def _handle_tool_calls(
self,
original_prompt: str,
response: Dict[str, Any],
max_rounds: int = 2,
**kwargs
) -> Dict[str, Any]:
"""
Execute the tool calls requested by the AI and query it again.

For up to ``max_rounds`` rounds: run the pending tool calls through
``mcp_client.batch_call_tools``, rebuild the prompt from the original
prompt plus the tool context, and call the provider's ``generate``.
The loop stops as soon as a response carries no tool calls, or when an
exception is raised.

Args:
original_prompt: The original user prompt.
response: AI response (containing ``tool_calls``).
max_rounds: Maximum number of tool-call rounds.
**kwargs: Optional overrides read here: provider, model, temperature,
max_tokens, system_prompt, tool_choice.

Returns:
``response`` unchanged when it has no tool calls or ``self.user_id``
is unset; otherwise a dict with keys content, tool_calls_made,
tools_used, finish_reason and mcp_enhanced.
"""
from app.mcp import mcp_client
tool_calls = response.get("tool_calls", [])
if not tool_calls or not self.user_id:
return response
# Starts from the first response's content; replaced only when a later round ends without tool calls
result = {
"content": response.get("content", ""),
"tool_calls_made": 0,
"tools_used": [],
"finish_reason": response.get("finish_reason", ""),
"mcp_enhanced": True
}
prompt = original_prompt
for round_num in range(max_rounds):
logger.info(f"🔧 工具调用 - 第{round_num+1}/{max_rounds}轮,{len(tool_calls)}个工具")
try:
# Execute the tool calls as one batch
tool_results = await mcp_client.batch_call_tools(
user_id=self.user_id,
tool_calls=tool_calls
)
# Record which tools were used (each name once)
for tc in tool_calls:
name = tc["function"]["name"]
if name not in result["tools_used"]:
result["tools_used"].append(name)
result["tool_calls_made"] += len(tool_calls)
# Build the tool context
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
# Update the prompt
if round_num == max_rounds - 1:
# Last round: demand a final answer and offer no tools
prompt = f"{original_prompt}\n\n{tool_context}\n\n⚠️ 重要:请基于以上工具查询结果,给出完整详细的最终答案。不要再调用工具。"
tool_choice = "none"
else:
prompt = f"{original_prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,继续完成任务。"
# NOTE(review): yields None (not "auto") when the caller passed tool_choice=None explicitly
tool_choice = kwargs.get("tool_choice", "auto")
# Call the AI again
prov = self._get_provider(kwargs.get("provider"))
next_response = await prov.generate(
prompt=prompt,
model=kwargs.get("model") or self.default_model,
# NOTE(review): `or` replaces an explicit 0 / 0.0 with the default -- confirm intended
temperature=kwargs.get("temperature") or self.default_temperature,
max_tokens=kwargs.get("max_tokens") or self.default_max_tokens,
system_prompt=kwargs.get("system_prompt") or self.default_system_prompt,
# Only the auto-loaded cached tools are re-sent here
tools=None if tool_choice == "none" else self._cached_tools,
tool_choice=tool_choice,
)
tool_calls = next_response.get("tool_calls", [])
if not tool_calls:
# No further tool calls: take this response as the result
result["content"] = next_response.get("content", "")
result["finish_reason"] = next_response.get("finish_reason", "stop")
break
except Exception as e:
# Any failure in the round falls back to the first response's content
logger.error(f"❌ 工具调用失败: {e}")
result["content"] = response.get("content", "")
result["finish_reason"] = "tool_error"
break
return result
async def generate_text(
self,
prompt: str,
@@ -89,10 +326,39 @@ class AIService:
system_prompt: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: Optional[str] = None,
auto_mcp: bool = True,
handle_tool_calls: bool = True,
mcp_max_rounds: Optional[int] = None,
) -> Dict[str, Any]:
"""生成文本"""
"""
生成文本(自动支持MCP工具)
Args:
prompt: 用户提示词
provider: AI提供商
model: 模型名称
temperature: 温度
max_tokens: 最大令牌数
system_prompt: 系统提示词
tools: 手动指定的工具列表(优先级高于自动加载)
tool_choice: 工具选择策略
auto_mcp: 是否自动加载MCP工具(默认True)
handle_tool_calls: 是否自动处理工具调用(默认True)
mcp_max_rounds: 最大工具调用轮数(None使用默认值3)
Returns:
包含生成内容的字典
"""
# 使用全局配置的MCP轮数(如果未指定)
if mcp_max_rounds is None:
mcp_max_rounds = app_settings.mcp_max_rounds
# 自动加载MCP工具
if auto_mcp and tools is None:
tools = await self._prepare_mcp_tools(auto_mcp=auto_mcp)
prov = self._get_provider(provider)
return await prov.generate(
response = await prov.generate(
prompt=prompt,
model=model or self.default_model,
temperature=temperature or self.default_temperature,
@@ -101,6 +367,22 @@ class AIService:
tools=tools,
tool_choice=tool_choice,
)
# 处理工具调用
if handle_tool_calls and response.get("tool_calls"):
return await self._handle_tool_calls(
original_prompt=prompt,
response=response,
provider=provider,
model=model,
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_prompt,
tool_choice=tool_choice,
max_rounds=mcp_max_rounds,
)
return response
async def generate_text_stream(
self,
@@ -110,15 +392,51 @@ class AIService:
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
system_prompt: Optional[str] = None,
tool_choice: Optional[str] = None,
auto_mcp: bool = True,
mcp_max_rounds: Optional[int] = None,
) -> AsyncGenerator[str, None]:
"""流式生成"""
"""
流式生成文本(自动支持MCP工具)
工具调用在 Provider 层通过流式方式处理,支持真正的流式工具调用。
Args:
prompt: 用户提示词
provider: AI提供商
model: 模型名称
temperature: 温度
max_tokens: 最大令牌数
system_prompt: 系统提示词
tool_choice: 工具选择策略("auto"/"none"/"required"
auto_mcp: 是否自动加载MCP工具
mcp_max_rounds: 最大工具调用轮数(None使用默认值3)
Yields:
生成的文本块
"""
logger.debug(f"🔧 generate_text_stream: auto_mcp={auto_mcp}, tool_choice={tool_choice}")
tools_to_use = None
# 加载MCP工具
if auto_mcp:
tools_to_use = await self._prepare_mcp_tools(auto_mcp=auto_mcp)
if tools_to_use:
logger.info(f"🔧 已获取 {len(tools_to_use)} 个MCP工具")
# 流式生成(Provider 层处理工具调用)
prov = self._get_provider(provider)
logger.debug(f"🔧 开始流式生成,provider={provider or self.api_provider}, tools_count={len(tools_to_use) if tools_to_use else 0}")
async for chunk in prov.generate_stream(
prompt=prompt,
model=model or self.default_model,
temperature=temperature or self.default_temperature,
max_tokens=max_tokens or self.default_max_tokens,
system_prompt=system_prompt or self.default_system_prompt,
tools=tools_to_use,
tool_choice=tool_choice,
user_id=self.user_id,
):
yield chunk
@@ -132,8 +450,25 @@ class AIService:
provider: Optional[str] = None,
model: Optional[str] = None,
expected_type: Optional[str] = None,
auto_mcp: bool = True,
) -> Union[Dict, List]:
"""带重试的 JSON 调用"""
"""
带重试的 JSON 调用(自动支持MCP工具)
Args:
prompt: 用户提示词
system_prompt: 系统提示词
max_retries: 最大重试次数
temperature: 温度
max_tokens: 最大令牌数
provider: AI提供商
model: 模型名称
expected_type: 期望的返回类型("object""array"
auto_mcp: 是否自动加载MCP工具
Returns:
解析后的JSON数据
"""
last_response = ""
for attempt in range(1, max_retries + 1):
@@ -146,6 +481,8 @@ class AIService:
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_prompt,
auto_mcp=auto_mcp,
handle_tool_calls=True,
)
last_response = result.get("content", "")
@@ -172,108 +509,6 @@ class AIService:
"""清洗 JSON 响应"""
return clean_json_response(text)
async def generate_text_with_mcp(
self,
prompt: str,
user_id: str,
db_session,
enable_mcp: bool = True,
max_tool_rounds: int = 3,
tool_choice: str = "auto",
**kwargs
) -> Dict[str, Any]:
"""支持MCP工具的AI文本生成"""
from app.services.mcp_tool_service import mcp_tool_service, MCPToolServiceError
result = {"content": "", "tool_calls_made": 0, "tools_used": [], "finish_reason": "", "mcp_enhanced": False}
tools = None
if enable_mcp:
try:
tools = await mcp_tool_service.get_user_enabled_tools(user_id=user_id, db_session=db_session)
if tools:
result["mcp_enhanced"] = True
except MCPToolServiceError:
tools = None
original_prompt = prompt # 保存原始提示词
for round_num in range(max_tool_rounds):
logger.debug(f"🔄 MCP工具调用 - 第{round_num+1}/{max_tool_rounds}")
logger.debug(f" prompt长度: {len(prompt)}, tools数量: {len(tools) if tools else 0}, tool_choice: {tool_choice}")
ai_response = await self.generate_text(prompt=prompt, tools=tools, tool_choice=tool_choice, **kwargs)
logger.debug(f" AI响应: finish_reason={ai_response.get('finish_reason')}, content长度={len(ai_response.get('content', ''))}")
tool_calls = ai_response.get("tool_calls", [])
if not tool_calls:
content = ai_response.get("content", "")
result["content"] = content
result["finish_reason"] = ai_response.get("finish_reason", "stop")
logger.debug(f" ✅ 无工具调用,返回内容长度: {len(content)}")
# 🔧 修复:如果内容为空且已经调用过工具,强制要求AI给出答案
if not content.strip() and result["tool_calls_made"] > 0:
logger.warning(f"⚠️ AI在工具调用后返回空内容,尝试强制要求回答(第{round_num+1}轮)")
prompt = f"{prompt}\n\n⚠️ 请注意:你必须基于以上工具查询结果,给出完整的回答。不要返回空内容。"
tools = None
tool_choice = "none" # 强制不使用工具
continue
break
logger.info(f"🔧 检测到 {len(tool_calls)} 个工具调用")
for idx, tc in enumerate(tool_calls):
logger.debug(f" 工具{idx+1}: {tc.get('function', {}).get('name')} - 参数: {tc.get('function', {}).get('arguments')}")
try:
logger.debug(f" 开始执行工具调用...")
tool_results = await mcp_tool_service.execute_tool_calls(user_id=user_id, tool_calls=tool_calls, db_session=db_session)
logger.debug(f" 工具执行完成,结果数量: {len(tool_results)}")
# 🔍 检查工具结果
for idx, tr in enumerate(tool_results):
success = tr.get("success", False)
content_preview = tr.get("content", "")[:200] if tr.get("content") else "None"
logger.debug(f" 工具结果[{idx}]: success={success}, content预览={content_preview}")
for tc in tool_calls:
name = tc["function"]["name"]
if name not in result["tools_used"]:
result["tools_used"].append(name)
result["tool_calls_made"] += len(tool_calls)
tool_context = await mcp_tool_service.build_tool_context(tool_results, format="markdown")
logger.debug(f" 工具上下文长度: {len(tool_context)}")
logger.debug(f" 工具上下文预览: {tool_context[:300] if len(tool_context) > 300 else tool_context}")
# 🔧 改进:在最后一轮时,明确要求AI给出完整答案
if round_num == max_tool_rounds - 1:
logger.info(f"⚠️ 最后一轮,强制要求AI给出最终答案")
prompt = f"{original_prompt}\n\n{tool_context}\n\n⚠️ 重要:这是最后一轮,请基于以上工具查询的参考资料,给出完整详细的最终答案。不要再调用工具。"
tool_choice = "none"
else:
prompt = f"{original_prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,继续完成任务。"
logger.debug(f" 新prompt长度: {len(prompt)}")
tools = None # 工具调用后禁用工具列表,避免重复调用
logger.debug(f" ✅ 工具调用成功,准备下一轮")
except Exception as tool_error:
logger.error(f"❌ 工具调用执行失败: {tool_error}", exc_info=True)
logger.error(f" 错误类型: {type(tool_error).__name__}")
logger.error(f" AI响应内容: {ai_response.get('content', '')[:200]}")
result["content"] = ai_response.get("content", "")
result["finish_reason"] = "tool_error"
break
return result
# 全局实例
ai_service = AIService()
def create_user_ai_service(
api_provider: str,
@@ -284,7 +519,7 @@ def create_user_ai_service(
max_tokens: int,
system_prompt: Optional[str] = None,
) -> AIService:
"""创建用户 AI 服务"""
"""创建用户 AI 服务(不带MCP支持)"""
return AIService(
api_provider=api_provider,
api_key=api_key,
@@ -293,4 +528,48 @@ def create_user_ai_service(
default_temperature=temperature,
default_max_tokens=max_tokens,
default_system_prompt=system_prompt,
)
def create_user_ai_service_with_mcp(
    api_provider: str,
    api_key: str,
    api_base_url: str,
    model_name: str,
    temperature: float,
    max_tokens: int,
    user_id: str,
    db_session,
    system_prompt: Optional[str] = None,
    enable_mcp: bool = True,
) -> AIService:
    """
    Build an AIService for one user with MCP support wired in.

    Args:
        api_provider: AI provider name.
        api_key: API key.
        api_base_url: API base URL.
        model_name: Model name, used as the service's default model.
        temperature: Default sampling temperature.
        max_tokens: Default maximum token count.
        user_id: User ID stored on the service for MCP tool loading.
        db_session: Database session stored on the service.
        system_prompt: Default system prompt.
        enable_mcp: Whether MCP tools are enabled.

    Returns:
        The configured AIService instance.
    """
    # Provider connection and generation defaults.
    service_options = {
        "api_provider": api_provider,
        "api_key": api_key,
        "api_base_url": api_base_url,
        "default_model": model_name,
        "default_temperature": temperature,
        "default_max_tokens": max_tokens,
        "default_system_prompt": system_prompt,
    }
    # MCP-specific wiring.
    service_options["user_id"] = user_id
    service_options["db_session"] = db_session
    service_options["enable_mcp"] = enable_mcp
    return AIService(**service_options)
+10 -24
View File
@@ -269,25 +269,11 @@ class AutoCharacterService:
)
try:
# 调用AI分析(使用统一的JSON调用方法)
if enable_mcp and user_id:
result = await self.ai_service.generate_text_with_mcp(
prompt=prompt,
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=2
)
content = result.get("content", "")
# 使用统一的JSON清洗方法
cleaned = self.ai_service._clean_json_response(content)
analysis = json.loads(cleaned)
else:
# 非MCP调用:使用带自动重试的JSON调用
analysis = await self.ai_service.call_with_json_retry(
prompt=prompt,
max_retries=3
)
# 使用统一的JSON调用方法(支持自动MCP工具加载
analysis = await self.ai_service.call_with_json_retry(
prompt=prompt,
max_retries=3,
)
logger.info(f" ✅ AI分析完成: needs_new_characters={analysis.get('needs_new_characters')}")
return analysis
@@ -364,16 +350,16 @@ class AutoCharacterService:
existing_characters=existing_chars_summary + careers_info,
plot_context="根据剧情需要引入的新角色",
character_specification=json.dumps(spec, ensure_ascii=False, indent=2),
mcp_references="" # 暂时不使用MCP增强
mcp_references="" # MCP工具通过AI服务自动加载
)
# 调用AI生成(禁用MCP,避免累积超时导致卡死)
logger.info(f"🔧 角色详情生成: enable_mcp={enable_mcp}")
# 调用AI生成
try:
# 🔧 优化:角色详情生成不使用MCP,只在分析阶段使用MCP
# 这样可以减少大量的外部工具调用,避免超时和卡死
character_data = await self.ai_service.call_with_json_retry(
prompt=prompt,
max_retries=2 # 减少重试次数以加快速度
max_retries=2, # 减少重试次数以加快速度
)
char_name = character_data.get('name', '未知')
@@ -292,25 +292,11 @@ class AutoOrganizationService:
)
try:
# 调用AI分析(使用统一的JSON调用方法)
if enable_mcp and user_id:
result = await self.ai_service.generate_text_with_mcp(
prompt=prompt,
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=2
)
content = result.get("content", "")
# 使用统一的JSON清洗方法
cleaned = self.ai_service._clean_json_response(content)
analysis = json.loads(cleaned)
else:
# 非MCP调用:使用带自动重试的JSON调用
analysis = await self.ai_service.call_with_json_retry(
prompt=prompt,
max_retries=3
)
# 使用统一的JSON调用方法(支持自动MCP工具加载
analysis = await self.ai_service.call_with_json_retry(
prompt=prompt,
max_retries=3,
)
logger.info(f" ✅ AI分析完成: needs_new_organizations={analysis.get('needs_new_organizations')}")
return analysis
@@ -362,24 +348,11 @@ class AutoOrganizationService:
# 调用AI生成(使用统一的JSON调用方法)
try:
if enable_mcp and user_id:
result = await self.ai_service.generate_text_with_mcp(
prompt=prompt,
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=2
)
content = result.get("content", "")
# 使用统一的JSON清洗方法
cleaned = self.ai_service._clean_json_response(content)
organization_data = json.loads(cleaned)
else:
# 非MCP调用:使用带自动重试的JSON调用
organization_data = await self.ai_service.call_with_json_retry(
prompt=prompt,
max_retries=3
)
# 使用统一的JSON调用方法(支持自动MCP工具加载)
organization_data = await self.ai_service.call_with_json_retry(
prompt=prompt,
max_retries=3,
)
org_name = organization_data.get('name', '未知')
logger.info(f" ✅ 组织详情生成成功: {org_name}")
+78 -60
View File
@@ -1,4 +1,7 @@
"""MCP插件测试服务 - 专门处理插件测试逻辑"""
"""MCP插件测试服务 - 专门处理插件测试逻辑
重构后使用统一的MCPClientFacade门面来管理所有MCP操作。
"""
import time
import json
@@ -10,7 +13,7 @@ from sqlalchemy import select
from app.models.mcp_plugin import MCPPlugin
from app.models.settings import Settings as UserSettings
from app.mcp.registry import mcp_registry
from app.mcp import mcp_client, MCPPluginConfig # 使用新的统一门面
from app.services.ai_service import create_user_ai_service
from app.schemas.mcp_plugin import MCPTestResult
from app.services.prompt_service import prompt_service
@@ -21,7 +24,32 @@ logger = get_logger(__name__)
class MCPTestService:
"""MCP插件测试服务(分离的测试逻辑"""
"""MCP插件测试服务(使用统一门面重构"""
async def _ensure_plugin_registered(
    self,
    plugin: MCPPlugin,
    user_id: str
) -> bool:
    """Register ``plugin`` with the unified MCP client facade if it is remote.

    Only plugins whose type is ``http``, ``streamable_http`` or ``sse`` and
    that have a non-empty ``server_url`` are registered; anything else is
    rejected without contacting the facade.

    Args:
        plugin: Plugin configuration record.
        user_id: ID of the user the plugin is registered for.

    Returns:
        The result of ``mcp_client.ensure_registered`` for eligible plugins,
        otherwise ``False``.
    """
    remote_types = ("http", "streamable_http", "sse")
    if plugin.plugin_type not in remote_types or not plugin.server_url:
        return False

    registered = await mcp_client.ensure_registered(
        user_id=user_id,
        plugin_name=plugin.plugin_name,
        url=plugin.server_url,
        plugin_type=plugin.plugin_type,
        headers=plugin.headers
    )
    return registered
async def test_plugin_connection(
self,
@@ -41,19 +69,18 @@ class MCPTestService:
start_time = time.time()
try:
# 确保插件已加载
if not mcp_registry.get_client(user_id, plugin.plugin_name):
success = await mcp_registry.load_plugin(plugin)
if not success:
return MCPTestResult(
success=False,
message="插件加载失败",
error="无法创建MCP客户端",
suggestions=["请检查插件配置", "请确认服务器URL正确"]
)
# 确保插件已注册
registered = await self._ensure_plugin_registered(plugin, user_id)
if not registered:
return MCPTestResult(
success=False,
message="插件注册失败",
error="无法创建MCP客户端",
suggestions=["请检查插件配置", "请确认服务器URL正确"]
)
# 测试连接并获取工具列表
test_result = await mcp_registry.test_plugin(user_id, plugin.plugin_name)
# 使用统一门面测试连接
test_result = await mcp_client.test_connection(user_id, plugin.plugin_name)
end_time = time.time()
response_time = round((end_time - start_time) * 1000, 2)
@@ -70,7 +97,18 @@ class MCPTestService:
]
)
else:
return MCPTestResult(**test_result)
return MCPTestResult(
success=False,
message="❌ 连接测试失败",
response_time_ms=response_time,
error=test_result.get("message", "未知错误"),
error_type=test_result.get("error_type"),
suggestions=[
"请检查服务器是否在线",
"请确认配置正确",
"请检查API Key是否有效"
]
)
except Exception as e:
end_time = time.time()
@@ -117,8 +155,8 @@ class MCPTestService:
if not connection_result.success:
return connection_result
# 2. 获取工具列表
tools = await mcp_registry.get_plugin_tools(user.user_id, plugin.plugin_name)
# 2. 使用统一门面获取工具列表
tools = await mcp_client.get_tools(user.user_id, plugin.plugin_name)
if not tools:
return MCPTestResult(
@@ -162,8 +200,8 @@ class MCPTestService:
max_tokens=1000
)
# 转换为OpenAI Function Calling格式
openai_tools = self._convert_tools_to_openai_format(tools)
# 使用统一门面转换为OpenAI Function Calling格式
openai_tools = mcp_client.format_tools_for_openai(tools, plugin.plugin_name)
logger.info(f"📋 转换后的OpenAI工具数量: {len(openai_tools)}")
logger.debug(f"📋 OpenAI工具列表: {[t['function']['name'] for t in openai_tools]}")
@@ -175,26 +213,16 @@ class MCPTestService:
db=db_session
)
# 注意: generate_text_stream 返回的是异步生成器,但在 tool_choice="required" 模式下
# AI服务会直接返回包含 tool_calls 的完整响应,而不是流式chunks
# 因此这里需要特殊处理
accumulated_text = ""
tool_calls = None
async for chunk in ai_service.generate_text_stream(
# 使用 generate_text 进行 Function Calling(非流式)
ai_response = await ai_service.generate_text(
prompt=prompts["user"],
system_prompt=prompts["system"],
tools=openai_tools,
tool_choice="required"
):
# 在 function calling 模式下,chunk 可能是字典格式包含 tool_calls
if isinstance(chunk, dict):
if "tool_calls" in chunk:
tool_calls = chunk["tool_calls"]
if "content" in chunk:
accumulated_text += chunk.get("content", "")
else:
accumulated_text += chunk
tool_choice="auto"
)
accumulated_text = ai_response.get("content", "")
tool_calls = ai_response.get("tool_calls")
# 5. 检查AI是否返回工具调用
if not tool_calls:
@@ -214,7 +242,7 @@ class MCPTestService:
# 6. 解析工具调用
tool_call = tool_calls[0]
function = tool_call["function"]
tool_name = function["name"]
tool_name_with_prefix = function["name"]
test_arguments = function["arguments"]
if isinstance(test_arguments, str):
@@ -231,17 +259,23 @@ class MCPTestService:
tools_count=len(tools)
)
# 解析插件名和工具名
try:
_, tool_name = mcp_client.parse_function_name(tool_name_with_prefix)
except ValueError:
tool_name = tool_name_with_prefix
logger.info(f"🤖 AI选择的工具: {tool_name}")
logger.info(f"📝 AI生成的参数: {test_arguments}")
# 7. 调用MCP工具
# 7. 使用统一门面调用MCP工具
call_start = time.time()
try:
tool_result = await mcp_registry.call_tool(
user.user_id,
plugin.plugin_name,
tool_name,
test_arguments
tool_result = await mcp_client.call_tool(
user_id=user.user_id,
plugin_name=plugin.plugin_name,
tool_name=tool_name,
arguments=test_arguments
)
call_end = time.time()
@@ -307,22 +341,6 @@ class MCPTestService:
"请检查API Key是否有效"
]
)
def _convert_tools_to_openai_format(self, tools: list) -> list:
"""将MCP工具格式转换为OpenAI Function Calling格式"""
openai_tools = []
for tool in tools:
openai_tool = {
"type": "function",
"function": {
"name": tool["name"],
"description": tool.get("description", ""),
}
}
if "inputSchema" in tool:
openai_tool["function"]["parameters"] = tool["inputSchema"]
openai_tools.append(openai_tool)
return openai_tools
# 全局单例
-691
View File
@@ -1,691 +0,0 @@
"""MCP工具服务 - 统一管理MCP工具的注入和执行"""
from typing import List, Dict, Any, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
import asyncio
import json
import time
from datetime import datetime, timedelta
from dataclasses import dataclass, field
from collections import defaultdict
from app.models.mcp_plugin import MCPPlugin
from app.mcp.registry import mcp_registry
from app.mcp.config import mcp_config
from app.logger import get_logger
logger = get_logger(__name__)
@dataclass
class ToolMetrics:
    """Running call statistics for a single tool."""

    total_calls: int = 0
    success_calls: int = 0
    failed_calls: int = 0
    total_duration_ms: float = 0.0
    avg_duration_ms: float = 0.0
    last_call_time: Optional[datetime] = None

    def _record_call(self, duration_ms: float) -> None:
        """Apply the bookkeeping shared by successful and failed calls."""
        self.total_calls += 1
        self.total_duration_ms += duration_ms
        self.avg_duration_ms = self.total_duration_ms / self.total_calls
        self.last_call_time = datetime.now()

    def update_success(self, duration_ms: float):
        """Record one successful call that took ``duration_ms`` milliseconds."""
        self.success_calls += 1
        self._record_call(duration_ms)

    def update_failure(self, duration_ms: float):
        """Record one failed call that took ``duration_ms`` milliseconds."""
        self.failed_calls += 1
        self._record_call(duration_ms)

    @property
    def success_rate(self) -> float:
        """Fraction of calls that succeeded; ``0.0`` before any call."""
        return self.success_calls / self.total_calls if self.total_calls else 0.0
@dataclass
class ToolCacheEntry:
    """One cached tool listing together with its expiry bookkeeping."""

    # Cached tool definitions for one cache key
    tools: List[Dict[str, Any]]
    # Point in time after which the entry must no longer be served
    expire_time: datetime
    # Number of times this entry has been served from the cache
    hit_count: int = 0
class MCPToolServiceError(Exception):
    """Raised when the MCP tool service cannot complete an operation."""
class MCPToolService:
    """Central service for exposing MCP tools to the AI and running its tool calls.

    The service keeps two pieces of in-memory state:

    * a TTL cache of raw tool listings keyed by ``"user_id:plugin_name"``;
    * per-tool call metrics keyed by ``"plugin.tool"``.
    """

    def __init__(
        self,
        cache_ttl_minutes: Optional[int] = None,
        max_retries: Optional[int] = None
    ):
        """Set up the tool cache, the metrics store and the retry settings.

        Args:
            cache_ttl_minutes: Tool cache TTL in minutes. A falsy value
                (``None`` or ``0``) selects ``mcp_config.TOOL_CACHE_TTL_MINUTES``.
            max_retries: Maximum number of attempts per tool call. A falsy
                value selects ``mcp_config.MAX_RETRIES``.
        """
        # Tool definition cache: {cache_key: ToolCacheEntry}
        self._tool_cache: Dict[str, ToolCacheEntry] = {}
        self._cache_ttl = timedelta(
            minutes=cache_ttl_minutes or mcp_config.TOOL_CACHE_TTL_MINUTES
        )

        # Call metrics: {tool_key: ToolMetrics}
        self._metrics: Dict[str, ToolMetrics] = defaultdict(ToolMetrics)

        # Retry settings taken from the MCP configuration constants
        self._max_retries = max_retries or mcp_config.MAX_RETRIES
        self._base_retry_delay = mcp_config.BASE_RETRY_DELAY_SECONDS
        self._max_retry_delay = mcp_config.MAX_RETRY_DELAY_SECONDS

        logger.info(
            f"✅ MCPToolService初始化完成 "
            f"(缓存TTL={self._cache_ttl.total_seconds()/60:.1f}分钟, "
            f"最大重试={self._max_retries}次)"
        )

    async def get_user_enabled_tools(
        self,
        user_id: str,
        db_session: AsyncSession,
        category: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Return the tools of all plugins the user has enabled.

        A plugin that cannot be loaded, or whose tools cannot be fetched, is
        logged and skipped; the remaining plugins are still processed.

        Args:
            user_id: User ID.
            db_session: Database session used to query the plugins.
            category: Optional plugin category filter.

        Returns:
            Tool definitions in function-calling format; an empty list when
            the user has no enabled plugin.

        Raises:
            MCPToolServiceError: If the plugin query (or any step outside the
                per-plugin handling) fails.
        """
        try:
            # Only ``enabled`` is filtered on. Status is deliberately ignored so
            # that a freshly enabled plugin still gets a chance to be called.
            query = select(MCPPlugin).where(
                MCPPlugin.user_id == user_id,
                MCPPlugin.enabled == True  # noqa: E712 - builds a SQL expression
            )
            if category:
                query = query.where(MCPPlugin.category == category)

            result = await db_session.execute(query)
            plugins = result.scalars().all()

            if not plugins:
                logger.info(f"用户 {user_id} 没有启用的MCP插件")
                return []

            all_tools = []
            for plugin in plugins:
                try:
                    # Make sure the plugin is loaded into the registry first
                    if not mcp_registry.get_client(user_id, plugin.plugin_name):
                        logger.info(f"插件 {plugin.plugin_name} 未加载,尝试加载...")
                        success = await mcp_registry.load_plugin(plugin)
                        if not success:
                            logger.warning(f"插件 {plugin.plugin_name} 加载失败,跳过")
                            continue

                    plugin_tools = await self._get_plugin_tools_cached(
                        user_id=user_id,
                        plugin_name=plugin.plugin_name
                    )
                    formatted_tools = self._format_tools_for_ai(
                        plugin_tools,
                        plugin.plugin_name
                    )
                    all_tools.extend(formatted_tools)

                    logger.info(
                        f"从插件 {plugin.plugin_name} 加载了 "
                        f"{len(formatted_tools)} 个工具"
                    )
                except Exception as e:
                    # Best effort: one broken plugin must not hide the others
                    logger.error(
                        f"获取插件 {plugin.plugin_name} 的工具失败: {e}",
                        exc_info=True
                    )
                    continue

            logger.info(f"用户 {user_id} 共加载 {len(all_tools)} 个MCP工具")
            return all_tools

        except Exception as e:
            logger.error(f"获取用户MCP工具失败: {e}", exc_info=True)
            raise MCPToolServiceError(f"获取MCP工具失败: {str(e)}") from e

    def _format_tools_for_ai(
        self,
        plugin_tools: List[Dict[str, Any]],
        plugin_name: str
    ) -> List[Dict[str, Any]]:
        """Convert MCP tool definitions into function-calling format.

        Args:
            plugin_tools: Tool list reported by the plugin.
            plugin_name: Plugin name, used as a name prefix.

        Returns:
            One ``{"type": "function", "function": {...}}`` entry per tool.
            A tool without ``inputSchema`` gets an empty object schema.
        """
        return [
            {
                "type": "function",
                "function": {
                    # Prefix with the plugin name to avoid clashes between plugins
                    "name": f"{plugin_name}_{tool['name']}",
                    "description": tool.get("description", ""),
                    "parameters": tool.get("inputSchema", {
                        "type": "object",
                        "properties": {},
                        "required": []
                    })
                }
            }
            for tool in plugin_tools
        ]

    async def _get_plugin_tools_cached(
        self,
        user_id: str,
        plugin_name: str
    ) -> List[Dict[str, Any]]:
        """Return a plugin's tool list, served from the TTL cache when fresh.

        Args:
            user_id: User ID.
            plugin_name: Plugin name.

        Returns:
            The plugin's raw tool list.
        """
        cache_key = f"{user_id}:{plugin_name}"
        now = datetime.now()

        entry = self._tool_cache.get(cache_key)
        if entry is not None:
            if now < entry.expire_time:
                entry.hit_count += 1
                logger.debug(
                    f"🎯 工具缓存命中: {cache_key} "
                    f"(命中次数: {entry.hit_count})"
                )
                return entry.tools
            logger.debug(f"⏰ 工具缓存过期: {cache_key}")
            del self._tool_cache[cache_key]

        # Cache miss: fetch from the registry and remember the result
        logger.debug(f"🔍 工具缓存未命中,从MCP获取: {cache_key}")
        tools = await mcp_registry.get_plugin_tools(user_id, plugin_name)

        self._tool_cache[cache_key] = ToolCacheEntry(
            tools=tools,
            expire_time=now + self._cache_ttl,
            hit_count=0
        )
        return tools

    def clear_cache(self, user_id: Optional[str] = None, plugin_name: Optional[str] = None):
        """Drop cached tool listings.

        Args:
            user_id: Restrict the cleanup to this user when given.
            plugin_name: Together with ``user_id``, restrict the cleanup to a
                single plugin. Passing ``plugin_name`` without ``user_id``
                clears nothing.
        """
        if user_id is None and plugin_name is None:
            # Clear everything
            self._tool_cache.clear()
            logger.info("🧹 已清理所有工具缓存")
        elif user_id and plugin_name:
            # Clear one plugin of one user
            cache_key = f"{user_id}:{plugin_name}"
            if cache_key in self._tool_cache:
                del self._tool_cache[cache_key]
                logger.info(f"🧹 已清理缓存: {cache_key}")
        elif user_id:
            # Clear every entry that belongs to the user
            keys_to_delete = [
                key for key in self._tool_cache
                if key.startswith(f"{user_id}:")
            ]
            for key in keys_to_delete:
                del self._tool_cache[key]
            logger.info(f"🧹 已清理用户缓存: {user_id} ({len(keys_to_delete)}个)")

    async def execute_tool_calls(
        self,
        user_id: str,
        tool_calls: List[Dict[str, Any]],
        db_session: AsyncSession,
        timeout: Optional[float] = None,
        max_concurrent: int = 2
    ) -> List[Dict[str, Any]]:
        """Execute the AI's tool calls in batches of bounded concurrency.

        Args:
            user_id: User ID.
            tool_calls: Tool calls returned by the AI.
            db_session: Database session, forwarded to each single call.
            timeout: Per-call timeout in seconds; a falsy value selects
                ``mcp_config.TOOL_CALL_TIMEOUT_SECONDS``.
            max_concurrent: Maximum number of calls run concurrently.
                Values below 1 are treated as 1.

        Returns:
            One result dict per tool call, in the order of ``tool_calls``.
        """
        if not tool_calls:
            return []

        actual_timeout = timeout or mcp_config.TOOL_CALL_TIMEOUT_SECONDS
        # A non-positive batch size would make range() raise ValueError
        batch_size = max(1, max_concurrent)

        logger.info(f"开始执行 {len(tool_calls)} 个工具调用 (超时={actual_timeout}s, 最大并发={max_concurrent})")

        total_batches = (len(tool_calls) + batch_size - 1) // batch_size
        all_results = []

        for i in range(0, len(tool_calls), batch_size):
            batch = tool_calls[i:i + batch_size]
            batch_num = i // batch_size + 1
            logger.info(f"执行工具批次 {batch_num}/{total_batches}, 数量: {len(batch)}")

            # Run the current batch concurrently; failures come back as values
            batch_results = await asyncio.gather(
                *(
                    self._execute_single_tool(
                        user_id=user_id,
                        tool_call=tool_call,
                        db_session=db_session,
                        timeout=actual_timeout
                    )
                    for tool_call in batch
                ),
                return_exceptions=True
            )

            for j, (tool_call, result) in enumerate(zip(batch, batch_results)):
                # BaseException also covers CancelledError, which gather()
                # returns as a value and which is not a result dict
                if isinstance(result, BaseException):
                    all_results.append({
                        "tool_call_id": tool_call.get("id", f"call_{i+j}"),
                        "role": "tool",
                        "name": tool_call["function"]["name"],
                        "content": f"工具调用失败: {str(result)}",
                        "success": False,
                        "error": str(result)
                    })
                else:
                    all_results.append(result)

            # Short pause between batches to avoid hitting API rate limits
            if i + batch_size < len(tool_calls):
                await asyncio.sleep(0.5)
                logger.debug("批次间延迟 0.5 秒...")

        return all_results

    async def _execute_single_tool(
        self,
        user_id: str,
        tool_call: Dict[str, Any],
        db_session: AsyncSession,
        timeout: float
    ) -> Dict[str, Any]:
        """Execute one tool call and convert the outcome into a result dict.

        The function name is split on its first underscore into plugin name
        and tool name. Any failure is recorded exactly once in the metrics
        and reported in the returned dict rather than raised.

        Args:
            user_id: User ID.
            tool_call: Tool call description (``id`` and ``function``).
            db_session: Database session (not used by this method).
            timeout: Timeout in seconds for each attempt.

        Returns:
            A ``role: "tool"`` result dict with ``success`` and ``error`` set.
        """
        tool_call_id = tool_call.get("id", "unknown")
        function_name = tool_call["function"]["name"]

        # Bound before the try block so the failure handler can always record
        # a metric, even when parsing fails before the call itself starts.
        tool_key = function_name
        start_time = time.time()

        try:
            logger.debug(f"🔍 解析工具名称: {function_name}")
            if "_" not in function_name:
                raise ValueError(f"无效的工具名称格式: {function_name}")
            plugin_name, tool_name = function_name.split("_", 1)
            logger.debug(f" 插件: {plugin_name}, 工具: {tool_name}")
            tool_key = f"{plugin_name}.{tool_name}"

            # Arguments may arrive as a JSON string or as an already parsed dict
            arguments_str = tool_call["function"]["arguments"]
            logger.debug("🔍 解析参数:")
            logger.debug(f" 原始类型: {type(arguments_str)}")
            logger.debug(f" 原始内容: {arguments_str}")

            if isinstance(arguments_str, str):
                try:
                    arguments = json.loads(arguments_str)
                    logger.debug(f" ✅ JSON解析成功: {arguments}")
                except json.JSONDecodeError as je:
                    logger.error(f" ❌ JSON解析失败: {je}")
                    logger.error(f" 原始字符串: '{arguments_str}'")
                    raise ValueError(f"参数JSON解析失败: {je}") from je
            else:
                arguments = arguments_str
                logger.debug(" 直接使用dict类型参数")

            logger.info(
                f"执行工具: {plugin_name}.{tool_name}, "
                f"参数: {arguments}"
            )

            # Restart the clock so the duration covers only the call itself
            start_time = time.time()
            try:
                result = await self._call_tool_with_retry(
                    user_id=user_id,
                    plugin_name=plugin_name,
                    tool_name=tool_name,
                    arguments=arguments,
                    timeout=timeout
                )
            except asyncio.TimeoutError:
                # Translated only; the failure metric is recorded once, below
                raise MCPToolServiceError(
                    f"工具调用超时(>{timeout}秒)"
                )

            # Serialize before recording success so that a result that cannot
            # be serialized is counted as a failure only.
            content = json.dumps(result, ensure_ascii=False)

            duration_ms = (time.time() - start_time) * 1000
            self._metrics[tool_key].update_success(duration_ms)
            logger.info(
                f"✅ 工具调用成功: {tool_key} "
                f"(耗时: {duration_ms:.2f}ms)"
            )

            return {
                "tool_call_id": tool_call_id,
                "role": "tool",
                "name": function_name,
                "content": content,
                "success": True,
                "error": None
            }

        except Exception as e:
            duration_ms = (time.time() - start_time) * 1000
            self._metrics[tool_key].update_failure(duration_ms)

            logger.error(
                f"❌ 工具 {function_name} 调用失败: {e}",
                exc_info=True
            )
            return {
                "tool_call_id": tool_call_id,
                "role": "tool",
                "name": function_name,
                "content": f"工具调用失败: {str(e)}",
                "success": False,
                "error": str(e)
            }

    async def _call_tool_with_retry(
        self,
        user_id: str,
        plugin_name: str,
        tool_name: str,
        arguments: Dict[str, Any],
        timeout: float
    ) -> Any:
        """Call a tool, retrying failures with exponential backoff.

        Args:
            user_id: User ID.
            plugin_name: Plugin name.
            tool_name: Tool name.
            arguments: Tool arguments.
            timeout: Timeout in seconds applied to each attempt.

        Returns:
            Whatever ``mcp_registry.call_tool`` returns.

        Raises:
            MCPToolServiceError: When every attempt failed.
            asyncio.TimeoutError: When an attempt timed out (never retried).
        """
        last_exception = None

        for attempt in range(self._max_retries):
            try:
                result = await asyncio.wait_for(
                    mcp_registry.call_tool(
                        user_id=user_id,
                        plugin_name=plugin_name,
                        tool_name=tool_name,
                        arguments=arguments
                    ),
                    timeout=timeout
                )

                if attempt > 0:
                    logger.info(
                        f"✅ 重试成功: {plugin_name}.{tool_name} "
                        f"(第{attempt + 1}次尝试)"
                    )
                return result

            except asyncio.TimeoutError:
                # Timeouts are not retried; propagate immediately
                raise

            except Exception as e:
                last_exception = e

                # The last attempt failed: give up
                if attempt == self._max_retries - 1:
                    logger.error(
                        f"❌ 重试失败: {plugin_name}.{tool_name} "
                        f"(已尝试{self._max_retries}次): {e}"
                    )
                    raise MCPToolServiceError(
                        f"工具调用失败(已重试{self._max_retries}次): {str(e)}"
                    ) from e

                # Exponential backoff, capped at the configured maximum
                delay = min(
                    self._base_retry_delay * (2 ** attempt),
                    self._max_retry_delay
                )
                logger.warning(
                    f"⚠️ 工具调用失败,{delay:.1f}秒后重试 "
                    f"(第{attempt + 1}/{self._max_retries}次): "
                    f"{plugin_name}.{tool_name} - {e}"
                )
                await asyncio.sleep(delay)

        # Only reached when the loop body never ran
        raise MCPToolServiceError(f"工具调用失败: {last_exception}")

    @staticmethod
    def _metric_to_dict(metric: Any) -> Dict[str, Any]:
        """Serialize one metrics record into a plain dict."""
        last_call = metric.last_call_time
        return {
            "total_calls": metric.total_calls,
            "success_calls": metric.success_calls,
            "failed_calls": metric.failed_calls,
            "success_rate": round(metric.success_rate, 3),
            "avg_duration_ms": round(metric.avg_duration_ms, 2),
            "last_call_time": last_call.isoformat() if last_call else None
        }

    def get_metrics(self, tool_name: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
        """Return call metrics for one tool or for all tools.

        Both paths use the same serialization, so ``success_rate`` is rounded
        to three decimals whether one tool or all tools are requested.

        Args:
            tool_name: Metrics key of a single tool; all tools when omitted.

        Returns:
            ``{tool_key: metrics_dict}``; empty when ``tool_name`` is unknown.
        """
        if tool_name:
            # Membership test first: indexing the defaultdict would create
            # an empty record for an unknown tool
            if tool_name in self._metrics:
                return {tool_name: self._metric_to_dict(self._metrics[tool_name])}
            return {}

        return {
            tool_key: self._metric_to_dict(metric)
            for tool_key, metric in self._metrics.items()
        }

    def get_cache_stats(self) -> Dict[str, Any]:
        """Return a summary of the tool cache and its per-entry hit counts."""
        total_entries = len(self._tool_cache)
        total_hits = sum(entry.hit_count for entry in self._tool_cache.values())

        return {
            "total_entries": total_entries,
            "total_hits": total_hits,
            "cache_ttl_minutes": self._cache_ttl.total_seconds() / 60,
            "entries": [
                {
                    "key": key,
                    "tools_count": len(entry.tools),
                    "hit_count": entry.hit_count,
                    "expire_time": entry.expire_time.isoformat()
                }
                for key, entry in self._tool_cache.items()
            ]
        }

    async def build_tool_context(
        self,
        tool_results: List[Dict[str, Any]],
        format: str = "markdown"
    ) -> str:
        """Render tool call results as context text.

        Args:
            tool_results: Tool call result dicts.
            format: ``"markdown"``, ``"json"``, or anything else for plain text.

        Returns:
            The rendered text; an empty string when there are no results.
        """
        if not tool_results:
            return ""

        if format == "markdown":
            return self._build_markdown_context(tool_results)
        elif format == "json":
            return json.dumps(tool_results, ensure_ascii=False, indent=2)
        else:  # plain
            return self._build_plain_context(tool_results)

    def _build_markdown_context(
        self,
        tool_results: List[Dict[str, Any]]
    ) -> str:
        """Render tool results as Markdown, one section per result."""
        lines = ["## 🔧 工具调用结果\n"]

        for i, result in enumerate(tool_results, 1):
            tool_name = result.get("name", "unknown")
            success = result.get("success", False)
            content = result.get("content", "")

            # Distinct markers so successes and failures can be told apart
            status_emoji = "✅" if success else "❌"
            lines.append(f"### {status_emoji} {i}. {tool_name}\n")

            if success:
                # Pretty-print JSON content; leave anything else untouched
                try:
                    content_obj = json.loads(content)
                    content = json.dumps(content_obj, ensure_ascii=False, indent=2)
                except (TypeError, ValueError):
                    pass
                lines.append(f"```json\n{content}\n```\n")
            else:
                lines.append(f"**错误**: {content}\n")

        return "\n".join(lines)

    def _build_plain_context(
        self,
        tool_results: List[Dict[str, Any]]
    ) -> str:
        """Render tool results as plain text, two lines per result."""
        lines = ["=== 工具调用结果 ===\n"]

        for i, result in enumerate(tool_results, 1):
            tool_name = result.get("name", "unknown")
            success = result.get("success", False)
            content = result.get("content", "")

            status = "成功" if success else "失败"
            lines.append(f"{i}. {tool_name} - {status}")
            lines.append(f" 结果: {content}\n")

        return "\n".join(lines)
# Module-level singleton, created with the configuration defaults
mcp_tool_service = MCPToolService()
+235
View File
@@ -0,0 +1,235 @@
"""MCP工具加载器 - 统一的工具获取入口
在AI请求之前,自动检查用户MCP配置并加载可用工具。
"""
from typing import Optional, List, Dict, Any
from datetime import datetime, timedelta
from dataclasses import dataclass, field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.logger import get_logger
from app.models.mcp_plugin import MCPPlugin
from app.mcp import mcp_client
logger = get_logger(__name__)
@dataclass
class UserToolsCache:
    """Cache entry holding one user's tool list and its expiry bookkeeping."""

    # Cached tool list; may be None
    tools: Optional[List[Dict[str, Any]]]
    # Point in time after which the entry must no longer be served
    expire_time: datetime
    # Number of times this entry has been served from the cache
    hit_count: int = 0
class MCPToolsLoader:
    """Single entry point for loading a user's MCP tools.

    Responsibilities:
      1. check whether the user has enabled MCP plugins;
      2. load the tool list of every enabled plugin;
      3. convert the tools to OpenAI function-calling format;
      4. cache the result per user to avoid repeated loading.
    """

    _instance: Optional['MCPToolsLoader'] = None

    def __new__(cls):
        """Return the shared instance, creating it on first use (singleton)."""
        if cls._instance is not None:
            return cls._instance
        loader = super().__new__(cls)
        loader._initialized = False
        cls._instance = loader
        return loader

    def __init__(self):
        # __init__ runs on every MCPToolsLoader() call; only set up state once
        if self._initialized:
            return

        # Per-user tool cache: user_id -> UserToolsCache
        self._cache: Dict[str, UserToolsCache] = {}
        # Cache TTL (5 minutes)
        self._cache_ttl = timedelta(minutes=5)
        self._initialized = True

        logger.info("✅ MCPToolsLoader 初始化完成")

    async def has_enabled_plugins(
        self,
        user_id: str,
        db_session: AsyncSession
    ) -> bool:
        """Tell whether the user has at least one enabled remote MCP plugin.

        Args:
            user_id: User ID.
            db_session: Database session used for the lookup.

        Returns:
            ``True`` when an enabled plugin of type ``http``,
            ``streamable_http`` or ``sse`` exists; ``False`` otherwise,
            including when the lookup itself fails.
        """
        try:
            stmt = select(MCPPlugin.id).where(
                MCPPlugin.user_id == user_id,
                MCPPlugin.enabled == True,
                MCPPlugin.plugin_type.in_(["http", "streamable_http", "sse"])
            ).limit(1)
            rows = await db_session.execute(stmt)
            first_id = rows.scalar()
            return first_id is not None
        except Exception as e:
            logger.warning(f"检查用户MCP插件失败: {e}")
            return False

    async def get_user_tools(
        self,
        user_id: str,
        db_session: AsyncSession,
        use_cache: bool = True,
        force_refresh: bool = False
    ) -> Optional[List[Dict[str, Any]]]:
        """Return the user's MCP tools in OpenAI function-calling format.

        A fresh load always overwrites the user's cache entry, even when
        ``use_cache`` is ``False``.

        Args:
            user_id: User ID.
            db_session: Database session.
            use_cache: Whether a fresh cache entry may be served.
            force_refresh: When ``True``, skip the cache lookup.

        Returns:
            The tool list, or ``None`` when the user has no enabled plugin,
            no tools could be loaded, or loading raised. An empty list is
            never returned: an empty result is reported as ``None``.
        """
        now = datetime.now()

        # Consult the cache unless the caller opted out
        cached = None
        if use_cache and not force_refresh:
            cached = self._cache.get(user_id)
        if cached is not None:
            if now < cached.expire_time:
                cached.hit_count += 1
                logger.debug(f"🎯 用户工具缓存命中: {user_id} (命中次数: {cached.hit_count})")
                return cached.tools
            del self._cache[user_id]
            logger.debug(f"⏰ 用户工具缓存过期: {user_id}")

        # Load from the database and the plugins
        try:
            tools = await self._load_user_tools(user_id, db_session)

            self._cache[user_id] = UserToolsCache(
                tools=tools,
                expire_time=now + self._cache_ttl
            )

            if tools:
                logger.info(f"🔧 用户 {user_id} 加载了 {len(tools)} 个MCP工具")
            else:
                logger.debug(f"📭 用户 {user_id} 没有可用的MCP工具")

            return tools

        except Exception as e:
            logger.error(f"❌ 加载用户MCP工具失败: {e}")
            return None

    async def _load_user_tools(
        self,
        user_id: str,
        db_session: AsyncSession
    ) -> Optional[List[Dict[str, Any]]]:
        """Query the user's enabled plugins and collect their tools.

        Plugins are processed in ``sort_order``. A plugin that fails is
        logged and skipped.

        Returns:
            The combined tool list, or ``None`` when there is no enabled
            plugin or no tool could be collected.
        """
        stmt = select(MCPPlugin).where(
            MCPPlugin.user_id == user_id,
            MCPPlugin.enabled == True,
            MCPPlugin.plugin_type.in_(["http", "streamable_http", "sse"])
        ).order_by(MCPPlugin.sort_order)

        rows = await db_session.execute(stmt)
        plugins = rows.scalars().all()

        if not plugins:
            return None

        collected = []
        for plugin in plugins:
            try:
                # Plain "http" is registered as streamable_http by default
                if plugin.plugin_type == "http":
                    plugin_type = "streamable_http"
                else:
                    plugin_type = plugin.plugin_type

                # Make sure the plugin is registered with the MCP client
                await mcp_client.ensure_registered(
                    user_id=user_id,
                    plugin_name=plugin.plugin_name,
                    url=plugin.server_url,
                    plugin_type=plugin_type,
                    headers=plugin.headers
                )

                plugin_tools = await mcp_client.get_tools(user_id, plugin.plugin_name)
                formatted = mcp_client.format_tools_for_openai(plugin_tools, plugin.plugin_name)
                collected.extend(formatted)

                logger.debug(f"✅ 从插件 {plugin.plugin_name} 加载了 {len(formatted)} 个工具")

            except Exception as e:
                logger.warning(f"⚠️ 加载插件 {plugin.plugin_name} 工具失败: {e}")
                continue

        if not collected:
            return None
        return collected

    def invalidate_cache(self, user_id: Optional[str] = None):
        """Invalidate cached tools.

        Args:
            user_id: User whose entry is dropped; a falsy value clears the
                whole cache.
        """
        if not user_id:
            count = len(self._cache)
            self._cache.clear()
            logger.info(f"🧹 清理所有用户工具缓存 ({count}个)")
            return

        # Cache values are never None, so pop() doubles as the membership test
        if self._cache.pop(user_id, None) is not None:
            logger.debug(f"🧹 清理用户工具缓存: {user_id}")

    def get_cache_stats(self) -> Dict[str, Any]:
        """Return a summary of the per-user cache and its entries."""
        now = datetime.now()

        entries = []
        total_hits = 0
        for uid, item in self._cache.items():
            total_hits += item.hit_count
            entries.append({
                "user_id": uid,
                "tools_count": len(item.tools) if item.tools else 0,
                "hit_count": item.hit_count,
                "expired": now >= item.expire_time,
                "expire_time": item.expire_time.isoformat()
            })

        return {
            "total_entries": len(self._cache),
            "total_hits": total_hits,
            "cache_ttl_minutes": self._cache_ttl.total_seconds() / 60,
            "entries": entries
        }
# Module-level singleton
mcp_tools_loader = MCPToolsLoader()