feat: 重构MCP功能和AI服务提供者架构
This commit is contained in:
@@ -71,7 +71,18 @@ class AnthropicClient:
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""
|
||||
流式生成,支持工具调用
|
||||
|
||||
Yields:
|
||||
Dict with keys:
|
||||
- content: str - 文本内容块
|
||||
- tool_calls: list - 工具调用列表(如果有)
|
||||
- done: bool - 是否结束
|
||||
"""
|
||||
kwargs = {
|
||||
"model": model,
|
||||
"max_tokens": max_tokens,
|
||||
@@ -80,12 +91,42 @@ class AnthropicClient:
|
||||
}
|
||||
if system_prompt:
|
||||
kwargs["system"] = system_prompt
|
||||
if tools:
|
||||
kwargs["tools"] = tools
|
||||
if tool_choice == "required":
|
||||
kwargs["tool_choice"] = {"type": "any"}
|
||||
elif tool_choice == "auto":
|
||||
kwargs["tool_choice"] = {"type": "auto"}
|
||||
|
||||
try:
|
||||
async with self.client.messages.stream(**kwargs) as stream:
|
||||
try:
|
||||
async for text in stream.text_stream:
|
||||
yield text
|
||||
tool_calls = []
|
||||
async for chunk in stream:
|
||||
# 处理不同类型的块
|
||||
if chunk.type == "text_delta":
|
||||
yield {"content": chunk.text}
|
||||
elif chunk.type == "tool_use_delta":
|
||||
# 工具调用增量
|
||||
if not tool_calls or tool_calls[-1].get("id") != chunk.id:
|
||||
tool_calls.append({
|
||||
"id": chunk.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": chunk.name,
|
||||
"arguments": ""
|
||||
}
|
||||
})
|
||||
# 追加参数
|
||||
if tool_calls[-1]["function"]["arguments"] is None:
|
||||
tool_calls[-1]["function"]["arguments"] = ""
|
||||
tool_calls[-1]["function"]["arguments"] += chunk.input_gets_new_text or ""
|
||||
elif chunk.type == "message_delta":
|
||||
if chunk.stop_reason:
|
||||
# 流结束
|
||||
if tool_calls:
|
||||
yield {"tool_calls": tool_calls}
|
||||
yield {"done": True, "finish_reason": chunk.stop_reason}
|
||||
except GeneratorExit:
|
||||
# 生成器被关闭,这是正常的清理过程
|
||||
logger.debug("Anthropic 流式响应生成器被关闭(GeneratorExit)")
|
||||
|
||||
@@ -111,7 +111,18 @@ class GeminiClient:
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""
|
||||
流式生成,支持工具调用
|
||||
|
||||
Yields:
|
||||
Dict with keys:
|
||||
- content: str - 文本内容块
|
||||
- tool_calls: list - 工具调用列表(如果有)
|
||||
- done: bool - 是否结束
|
||||
"""
|
||||
url = f"{self.base_url}/models/{model}:streamGenerateContent?key={self.api_key}&alt=sse"
|
||||
|
||||
contents = []
|
||||
@@ -125,6 +136,8 @@ class GeminiClient:
|
||||
}
|
||||
if system_prompt:
|
||||
payload["systemInstruction"] = {"parts": [{"text": system_prompt}]}
|
||||
if tools:
|
||||
payload["tools"] = self._convert_tools_to_gemini(tools)
|
||||
|
||||
try:
|
||||
async with self.client.stream("POST", url, json=payload) as response:
|
||||
@@ -139,9 +152,26 @@ class GeminiClient:
|
||||
if candidates and len(candidates) > 0:
|
||||
parts = candidates[0].get("content", {}).get("parts", [])
|
||||
if parts and len(parts) > 0:
|
||||
text = parts[0].get("text", "")
|
||||
text = ""
|
||||
function_calls = []
|
||||
for part in parts:
|
||||
if "text" in part:
|
||||
text += part["text"]
|
||||
elif "functionCall" in part:
|
||||
fc = part["functionCall"]
|
||||
function_calls.append({
|
||||
"id": f"call_{fc['name']}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": fc["name"],
|
||||
"arguments": fc.get("args", {})
|
||||
}
|
||||
})
|
||||
|
||||
if text:
|
||||
yield text
|
||||
yield {"content": text}
|
||||
if function_calls:
|
||||
yield {"tool_calls": function_calls}
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
except GeneratorExit:
|
||||
|
||||
@@ -86,8 +86,21 @@ class OpenAIClient(BaseAIClient):
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
payload = self._build_payload(messages, model, temperature, max_tokens, stream=True)
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""
|
||||
流式生成,支持工具调用
|
||||
|
||||
Yields:
|
||||
Dict with keys:
|
||||
- content: str - 文本内容块
|
||||
- tool_calls: list - 工具调用列表(如果有)
|
||||
- done: bool - 是否结束
|
||||
"""
|
||||
payload = self._build_payload(messages, model, temperature, max_tokens, tools, tool_choice, stream=True)
|
||||
|
||||
tool_calls_buffer = {} # 收集工具调用块
|
||||
|
||||
try:
|
||||
async with await self._request_with_retry("POST", "/chat/completions", payload, stream=True) as response:
|
||||
@@ -97,14 +110,38 @@ class OpenAIClient(BaseAIClient):
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
if data_str.strip() == "[DONE]":
|
||||
# 流结束,检查是否有工具调用需要处理
|
||||
if tool_calls_buffer:
|
||||
yield {"tool_calls": list(tool_calls_buffer.values()), "done": True}
|
||||
yield {"done": True}
|
||||
break
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
choices = data.get("choices", [])
|
||||
if choices and len(choices) > 0:
|
||||
content = choices[0].get("delta", {}).get("content", "")
|
||||
delta = choices[0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
|
||||
# 检查工具调用
|
||||
tc_list = delta.get("tool_calls")
|
||||
if tc_list:
|
||||
for tc in tc_list:
|
||||
index = tc.get("index", 0)
|
||||
if index not in tool_calls_buffer:
|
||||
tool_calls_buffer[index] = tc
|
||||
else:
|
||||
existing = tool_calls_buffer[index]
|
||||
# 合并 function.arguments
|
||||
if "function" in tc and "function" in existing:
|
||||
if tc["function"].get("arguments"):
|
||||
existing["function"]["arguments"] = (
|
||||
existing["function"].get("arguments", "") +
|
||||
tc["function"]["arguments"]
|
||||
)
|
||||
|
||||
if content:
|
||||
yield content
|
||||
yield {"content": content}
|
||||
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
except GeneratorExit:
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
"""Anthropic Provider"""
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from app.logger import get_logger
|
||||
from app.services.ai_clients.anthropic_client import AnthropicClient
|
||||
from .base_provider import BaseAIProvider
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AnthropicProvider(BaseAIProvider):
|
||||
"""Anthropic 提供商"""
|
||||
@@ -39,7 +42,62 @@ class AnthropicProvider(BaseAIProvider):
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
# 如果有工具,使用真正的流式工具调用
|
||||
if tools:
|
||||
logger.debug(f"🔧 AnthropicProvider: 有 {len(tools)} 个工具,使用流式处理")
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
actual_tool_choice = tool_choice if tool_choice else "auto"
|
||||
|
||||
tool_calls_buffer = []
|
||||
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
tool_choice=actual_tool_choice,
|
||||
):
|
||||
# 检查是否有工具调用
|
||||
if chunk.get("tool_calls"):
|
||||
tool_calls_buffer.extend(chunk["tool_calls"])
|
||||
logger.debug(f"🔧 收到工具调用: {len(chunk['tool_calls'])} 个")
|
||||
|
||||
# 检查是否结束
|
||||
if chunk.get("done"):
|
||||
if tool_calls_buffer:
|
||||
logger.info(f"🔧 流式结束,处理 {len(tool_calls_buffer)} 个工具调用")
|
||||
from app.mcp import mcp_client
|
||||
actual_user_id = user_id or ""
|
||||
tool_results = await mcp_client.batch_call_tools(
|
||||
user_id=actual_user_id,
|
||||
tool_calls=tool_calls_buffer
|
||||
)
|
||||
# 将工具结果注入到上下文中
|
||||
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
|
||||
|
||||
# 构建最终提示词,要求AI基于工具结果回答
|
||||
final_prompt = f"{prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"
|
||||
final_messages = [{"role": "user", "content": final_prompt}]
|
||||
|
||||
# 递归调用生成最终结果
|
||||
async for final_chunk in self._generate_with_tools(
|
||||
final_messages, model, temperature, max_tokens, system_prompt, tools, user_id
|
||||
):
|
||||
yield final_chunk
|
||||
break
|
||||
|
||||
# 输出文本内容
|
||||
if chunk.get("content"):
|
||||
yield chunk["content"]
|
||||
return
|
||||
|
||||
# 无工具时普通流式生成
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
@@ -48,4 +106,56 @@ class AnthropicProvider(BaseAIProvider):
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
):
|
||||
yield chunk
|
||||
# 确保只 yield 字符串内容,避免 yield 字典导致类型错误
|
||||
if isinstance(chunk, dict):
|
||||
if chunk.get("content"):
|
||||
yield chunk["content"]
|
||||
else:
|
||||
yield chunk
|
||||
|
||||
async def _generate_with_tools(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: list = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""辅助方法:带工具的流式生成"""
|
||||
tool_calls_buffer = []
|
||||
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
):
|
||||
if chunk.get("tool_calls"):
|
||||
tool_calls_buffer.extend(chunk["tool_calls"])
|
||||
logger.debug(f"🔧 _generate_with_tools 收到工具调用: {len(chunk['tool_calls'])} 个")
|
||||
|
||||
if chunk.get("done"):
|
||||
if tool_calls_buffer:
|
||||
from app.mcp import mcp_client
|
||||
actual_user_id = user_id or ""
|
||||
tool_results = await mcp_client.batch_call_tools(
|
||||
user_id=actual_user_id,
|
||||
tool_calls=tool_calls_buffer
|
||||
)
|
||||
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
|
||||
|
||||
messages.append({"role": "user", "content": f"{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"})
|
||||
|
||||
async for final_chunk in self._generate_with_tools(
|
||||
messages, model, temperature, max_tokens, system_prompt, tools, user_id
|
||||
):
|
||||
yield final_chunk
|
||||
break
|
||||
|
||||
if chunk.get("content"):
|
||||
yield chunk["content"]
|
||||
@@ -28,6 +28,9 @@ class BaseAIProvider(ABC):
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""流式生成"""
|
||||
pass
|
||||
@@ -1,8 +1,12 @@
|
||||
"""Gemini Provider"""
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from app.logger import get_logger
|
||||
from app.services.ai_clients.gemini_client import GeminiClient
|
||||
from .base_provider import BaseAIProvider
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class GeminiProvider(BaseAIProvider):
|
||||
def __init__(self, client: GeminiClient):
|
||||
@@ -36,7 +40,62 @@ class GeminiProvider(BaseAIProvider):
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
# 如果有工具,使用真正的流式工具调用
|
||||
if tools:
|
||||
logger.debug(f"🔧 GeminiProvider: 有 {len(tools)} 个工具,使用流式处理")
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
actual_tool_choice = tool_choice if tool_choice else "auto"
|
||||
|
||||
tool_calls_buffer = []
|
||||
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
tool_choice=actual_tool_choice,
|
||||
):
|
||||
# 检查是否有工具调用
|
||||
if chunk.get("tool_calls"):
|
||||
tool_calls_buffer.extend(chunk["tool_calls"])
|
||||
logger.debug(f"🔧 收到工具调用: {len(chunk['tool_calls'])} 个")
|
||||
|
||||
# 检查是否结束
|
||||
if chunk.get("done"):
|
||||
if tool_calls_buffer:
|
||||
logger.info(f"🔧 流式结束,处理 {len(tool_calls_buffer)} 个工具调用")
|
||||
from app.mcp import mcp_client
|
||||
actual_user_id = user_id or ""
|
||||
tool_results = await mcp_client.batch_call_tools(
|
||||
user_id=actual_user_id,
|
||||
tool_calls=tool_calls_buffer
|
||||
)
|
||||
# 将工具结果注入到上下文中
|
||||
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
|
||||
|
||||
# 构建最终提示词,要求AI基于工具结果回答
|
||||
final_prompt = f"{prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"
|
||||
final_messages = [{"role": "user", "content": final_prompt}]
|
||||
|
||||
# 递归调用生成最终结果
|
||||
async for final_chunk in self._generate_with_tools(
|
||||
final_messages, model, temperature, max_tokens, system_prompt, tools, user_id
|
||||
):
|
||||
yield final_chunk
|
||||
break
|
||||
|
||||
# 输出文本内容
|
||||
if chunk.get("content"):
|
||||
yield chunk["content"]
|
||||
return
|
||||
|
||||
# 无工具时普通流式生成
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
@@ -45,4 +104,56 @@ class GeminiProvider(BaseAIProvider):
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
):
|
||||
yield chunk
|
||||
# 确保只 yield 字符串内容,避免 yield 字典导致类型错误
|
||||
if isinstance(chunk, dict):
|
||||
if chunk.get("content"):
|
||||
yield chunk["content"]
|
||||
else:
|
||||
yield chunk
|
||||
|
||||
async def _generate_with_tools(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: list = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""辅助方法:带工具的流式生成"""
|
||||
tool_calls_buffer = []
|
||||
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
):
|
||||
if chunk.get("tool_calls"):
|
||||
tool_calls_buffer.extend(chunk["tool_calls"])
|
||||
logger.debug(f"🔧 _generate_with_tools 收到工具调用: {len(chunk['tool_calls'])} 个")
|
||||
|
||||
if chunk.get("done"):
|
||||
if tool_calls_buffer:
|
||||
from app.mcp import mcp_client
|
||||
actual_user_id = user_id or ""
|
||||
tool_results = await mcp_client.batch_call_tools(
|
||||
user_id=actual_user_id,
|
||||
tool_calls=tool_calls_buffer
|
||||
)
|
||||
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
|
||||
|
||||
messages.append({"role": "user", "content": f"{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"})
|
||||
|
||||
async for final_chunk in self._generate_with_tools(
|
||||
messages, model, temperature, max_tokens, system_prompt, tools, user_id
|
||||
):
|
||||
yield final_chunk
|
||||
break
|
||||
|
||||
if chunk.get("content"):
|
||||
yield chunk["content"]
|
||||
@@ -1,9 +1,12 @@
|
||||
"""OpenAI Provider"""
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from app.logger import get_logger
|
||||
from app.services.ai_clients.openai_client import OpenAIClient
|
||||
from .base_provider import BaseAIProvider
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIProvider(BaseAIProvider):
|
||||
"""OpenAI 提供商"""
|
||||
@@ -42,16 +45,117 @@ class OpenAIProvider(BaseAIProvider):
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
# 如果有工具,使用真正的流式工具调用
|
||||
if tools:
|
||||
logger.debug(f"🔧 OpenAIProvider: 有 {len(tools)} 个工具,使用流式处理")
|
||||
actual_tool_choice = tool_choice if tool_choice else "auto"
|
||||
|
||||
tool_calls_buffer = []
|
||||
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
tools=tools,
|
||||
tool_choice=actual_tool_choice,
|
||||
):
|
||||
# 检查是否有工具调用
|
||||
if chunk.get("tool_calls"):
|
||||
tool_calls_buffer.extend(chunk["tool_calls"])
|
||||
logger.debug(f"🔧 收到工具调用: {len(chunk['tool_calls'])} 个")
|
||||
|
||||
# 检查是否结束
|
||||
if chunk.get("done"):
|
||||
if tool_calls_buffer:
|
||||
logger.info(f"🔧 流式结束,处理 {len(tool_calls_buffer)} 个工具调用")
|
||||
from app.mcp import mcp_client
|
||||
actual_user_id = user_id or ""
|
||||
tool_results = await mcp_client.batch_call_tools(
|
||||
user_id=actual_user_id,
|
||||
tool_calls=tool_calls_buffer
|
||||
)
|
||||
# 将工具结果注入到上下文中
|
||||
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
|
||||
|
||||
# 构建最终提示词,要求AI基于工具结果回答
|
||||
final_prompt = f"{prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"
|
||||
final_messages = messages.copy()
|
||||
final_messages.append({"role": "user", "content": final_prompt})
|
||||
|
||||
# 递归调用生成最终结果
|
||||
async for final_chunk in self._generate_with_tools(
|
||||
final_messages, model, temperature, max_tokens, tools, user_id
|
||||
):
|
||||
yield final_chunk
|
||||
break
|
||||
|
||||
# 输出文本内容
|
||||
if chunk.get("content"):
|
||||
yield chunk["content"]
|
||||
return
|
||||
|
||||
# 无工具时普通流式生成
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
):
|
||||
yield chunk
|
||||
# 确保只 yield 字符串内容,避免 yield 字典导致类型错误
|
||||
if isinstance(chunk, dict):
|
||||
if chunk.get("content"):
|
||||
yield chunk["content"]
|
||||
else:
|
||||
yield chunk
|
||||
|
||||
async def _generate_with_tools(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
tools: list,
|
||||
user_id: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""辅助方法:带工具的流式生成(无tool_choice,AI自由决定)"""
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
):
|
||||
if chunk.get("tool_calls"):
|
||||
from app.mcp import mcp_client
|
||||
actual_user_id = user_id or ""
|
||||
tool_results = await mcp_client.batch_call_tools(
|
||||
user_id=actual_user_id,
|
||||
tool_calls=chunk["tool_calls"]
|
||||
)
|
||||
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
|
||||
|
||||
# 再次调用获取最终回答
|
||||
messages.append({"role": "user", "content": f"{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"})
|
||||
|
||||
async for final_chunk in self._generate_with_tools(
|
||||
messages, model, temperature, max_tokens, tools, user_id
|
||||
):
|
||||
yield final_chunk
|
||||
break
|
||||
|
||||
if chunk.get("done"):
|
||||
break
|
||||
|
||||
if chunk.get("content"):
|
||||
yield chunk["content"]
|
||||
+391
-112
@@ -1,4 +1,10 @@
|
||||
"""AI服务封装 - 统一的AI接口"""
|
||||
"""AI服务封装 - 统一的AI接口
|
||||
|
||||
重构后支持自动MCP工具加载:
|
||||
- 所有AI方法在请求前自动检查用户MCP配置
|
||||
- 如果有启用的MCP插件且有可用工具,自动发送tools
|
||||
- 通过 auto_mcp 参数控制是否启用自动工具加载
|
||||
"""
|
||||
from typing import Optional, AsyncGenerator, List, Dict, Any, Union
|
||||
|
||||
from app.config import settings as app_settings
|
||||
@@ -13,7 +19,6 @@ from app.services.ai_providers.anthropic_provider import AnthropicProvider
|
||||
from app.services.ai_providers.gemini_provider import GeminiProvider
|
||||
from app.services.ai_providers.base_provider import BaseAIProvider
|
||||
from app.services.json_helper import clean_json_response, parse_json
|
||||
from app.mcp.adapters.universal import universal_mcp_adapter
|
||||
|
||||
# 导出清理函数
|
||||
cleanup_http_clients = cleanup_all_clients
|
||||
@@ -22,7 +27,41 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AIService:
|
||||
"""AI服务统一接口"""
|
||||
"""
|
||||
AI服务统一接口
|
||||
|
||||
MCP工具支持:
|
||||
- 在创建服务时传入 user_id 和 db_session
|
||||
- 根据用户MCP插件的enabled状态自动决定是否启用MCP
|
||||
- 如果有任意一个MCP插件启用,则加载并使用工具
|
||||
- 如果所有插件都关闭,则不使用任何MCP工具
|
||||
- 通过 auto_mcp=False 可临时禁用自动工具加载
|
||||
- 通过 mcp_max_rounds 控制工具调用轮数
|
||||
- 通过 clear_mcp_cache() 可清理MCP工具缓存
|
||||
|
||||
MCP启用逻辑(backend/app/api/settings.py 中的 get_user_ai_service):
|
||||
- 查询用户的所有MCP插件
|
||||
- 如果有启用的插件 (enabled=True),则 enable_mcp=True
|
||||
- 如果所有插件都关闭或没有插件,则 enable_mcp=False
|
||||
|
||||
使用示例:
|
||||
# 创建支持MCP的AI服务(根据插件状态自动决定是否启用)
|
||||
ai_service = create_user_ai_service_with_mcp(
|
||||
api_provider="openai",
|
||||
api_key="...",
|
||||
user_id="user123",
|
||||
db_session=db
|
||||
)
|
||||
|
||||
# 自动加载MCP工具(如果有启用的插件)
|
||||
result = await ai_service.generate_text(prompt="...")
|
||||
|
||||
# 临时禁用MCP工具
|
||||
result = await ai_service.generate_text(prompt="...", auto_mcp=False)
|
||||
|
||||
# 自定义轮数
|
||||
result = await ai_service.generate_text(prompt="...", mcp_max_rounds=3)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -33,8 +72,11 @@ class AIService:
|
||||
default_temperature: Optional[float] = None,
|
||||
default_max_tokens: Optional[int] = None,
|
||||
default_system_prompt: Optional[str] = None,
|
||||
enable_mcp_adapter: bool = True,
|
||||
config: Optional[AIClientConfig] = None,
|
||||
# MCP支持参数
|
||||
user_id: Optional[str] = None,
|
||||
db_session: Optional[Any] = None,
|
||||
enable_mcp: bool = True,
|
||||
):
|
||||
self.api_provider = api_provider or app_settings.default_ai_provider
|
||||
self.default_model = default_model or app_settings.default_model
|
||||
@@ -43,7 +85,12 @@ class AIService:
|
||||
self.default_system_prompt = default_system_prompt
|
||||
self.config = config or default_config
|
||||
|
||||
self.mcp_adapter = universal_mcp_adapter if enable_mcp_adapter else None
|
||||
# MCP配置
|
||||
self.user_id = user_id
|
||||
self.db_session = db_session
|
||||
self._enable_mcp = enable_mcp
|
||||
self._cached_tools: Optional[List[Dict]] = None
|
||||
self._tools_loaded = False
|
||||
|
||||
self._openai_provider: Optional[OpenAIProvider] = None
|
||||
self._anthropic_provider: Optional[AnthropicProvider] = None
|
||||
@@ -68,6 +115,36 @@ class AIService:
|
||||
client = GeminiClient(api_key, api_base_url, self.config)
|
||||
self._gemini_provider = GeminiProvider(client)
|
||||
|
||||
@property
|
||||
def enable_mcp(self) -> bool:
|
||||
"""是否启用MCP工具"""
|
||||
return self._enable_mcp
|
||||
|
||||
@enable_mcp.setter
|
||||
def enable_mcp(self, value: bool):
|
||||
"""设置MCP启用状态,如果禁用则清理缓存"""
|
||||
if value is False and self._enable_mcp is True:
|
||||
# 从启用变为禁用,清理缓存
|
||||
self.clear_mcp_cache()
|
||||
self._enable_mcp = value
|
||||
|
||||
def clear_mcp_cache(self):
|
||||
"""
|
||||
清理MCP工具缓存
|
||||
|
||||
当禁用MCP时调用此方法,确保后续AI调用不会使用缓存的工具。
|
||||
同时更新 _tools_loaded 状态,使下次调用时重新检查。
|
||||
"""
|
||||
if self._cached_tools is not None:
|
||||
logger.info(f"🔧 清理MCP工具缓存,移除 {len(self._cached_tools)} 个工具")
|
||||
self._cached_tools = None
|
||||
else:
|
||||
logger.debug(f"🔧 MCP工具缓存已经是空,无需清理")
|
||||
|
||||
# 更新加载状态,确保下次调用会重新检查
|
||||
self._tools_loaded = False
|
||||
logger.debug(f"🔧 MCP工具状态已重置: enable_mcp={self._enable_mcp}, _tools_loaded=False")
|
||||
|
||||
def _get_provider(self, provider: Optional[str] = None) -> BaseAIProvider:
|
||||
"""获取对应的 Provider"""
|
||||
p = provider or self.api_provider
|
||||
@@ -79,6 +156,166 @@ class AIService:
|
||||
return self._gemini_provider
|
||||
raise ValueError(f"Provider {p} 未初始化")
|
||||
|
||||
async def _prepare_mcp_tools(self, auto_mcp: bool = True, force_refresh: bool = False) -> Optional[List[Dict]]:
|
||||
"""
|
||||
预处理MCP工具
|
||||
|
||||
检查用户MCP配置并加载可用工具。
|
||||
结果会被缓存,避免重复加载。
|
||||
|
||||
Args:
|
||||
auto_mcp: 是否自动加载MCP工具(来自调用方参数)
|
||||
force_refresh: 是否强制刷新缓存
|
||||
|
||||
Returns:
|
||||
- None: 无可用工具(未配置/未启用/加载失败)
|
||||
- List[Dict]: OpenAI格式的工具列表
|
||||
"""
|
||||
# 前置条件检查
|
||||
if not self._enable_mcp:
|
||||
logger.debug(f"🔧 MCP工具未启用 (_enable_mcp=False)")
|
||||
# 即使有缓存也清理掉,确保不使用
|
||||
self._cached_tools = None
|
||||
self._tools_loaded = False
|
||||
return None
|
||||
|
||||
if not auto_mcp:
|
||||
logger.debug(f"🔧 auto_mcp=False,跳过MCP工具加载")
|
||||
# 即使有缓存也清理掉,确保不使用
|
||||
self._cached_tools = None
|
||||
self._tools_loaded = False
|
||||
return None
|
||||
|
||||
if not self.user_id:
|
||||
logger.debug(f"🔧 MCP工具加载跳过: user_id未设置")
|
||||
return None
|
||||
|
||||
if not self.db_session:
|
||||
logger.debug(f"🔧 MCP工具加载跳过: db_session未设置")
|
||||
return None
|
||||
|
||||
# 使用缓存(只有 enable_mcp=True 时才使用缓存)
|
||||
if self._tools_loaded and not force_refresh:
|
||||
if self._cached_tools:
|
||||
logger.debug(f"🔧 使用缓存的MCP工具 ({len(self._cached_tools)}个)")
|
||||
return self._cached_tools
|
||||
|
||||
try:
|
||||
from app.services.mcp_tools_loader import mcp_tools_loader
|
||||
|
||||
self._cached_tools = await mcp_tools_loader.get_user_tools(
|
||||
user_id=self.user_id,
|
||||
db_session=self.db_session,
|
||||
use_cache=True,
|
||||
force_refresh=force_refresh
|
||||
)
|
||||
self._tools_loaded = True
|
||||
|
||||
if self._cached_tools:
|
||||
logger.info(f"🔧 已加载 {len(self._cached_tools)} 个MCP工具")
|
||||
else:
|
||||
logger.debug(f"📭 用户 {self.user_id} 没有可用的MCP工具")
|
||||
|
||||
return self._cached_tools
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 加载MCP工具失败: {e}")
|
||||
self._tools_loaded = True
|
||||
self._cached_tools = None
|
||||
return None
|
||||
|
||||
async def _handle_tool_calls(
|
||||
self,
|
||||
original_prompt: str,
|
||||
response: Dict[str, Any],
|
||||
max_rounds: int = 2,
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
处理AI返回的工具调用
|
||||
|
||||
Args:
|
||||
original_prompt: 原始提示词
|
||||
response: AI响应(包含tool_calls)
|
||||
max_rounds: 最大工具调用轮数
|
||||
**kwargs: 传递给generate_text的其他参数
|
||||
|
||||
Returns:
|
||||
最终的AI响应
|
||||
"""
|
||||
from app.mcp import mcp_client
|
||||
|
||||
tool_calls = response.get("tool_calls", [])
|
||||
if not tool_calls or not self.user_id:
|
||||
return response
|
||||
|
||||
result = {
|
||||
"content": response.get("content", ""),
|
||||
"tool_calls_made": 0,
|
||||
"tools_used": [],
|
||||
"finish_reason": response.get("finish_reason", ""),
|
||||
"mcp_enhanced": True
|
||||
}
|
||||
|
||||
prompt = original_prompt
|
||||
|
||||
for round_num in range(max_rounds):
|
||||
logger.info(f"🔧 工具调用 - 第{round_num+1}/{max_rounds}轮,{len(tool_calls)}个工具")
|
||||
|
||||
try:
|
||||
# 批量执行工具调用
|
||||
tool_results = await mcp_client.batch_call_tools(
|
||||
user_id=self.user_id,
|
||||
tool_calls=tool_calls
|
||||
)
|
||||
|
||||
# 记录使用的工具
|
||||
for tc in tool_calls:
|
||||
name = tc["function"]["name"]
|
||||
if name not in result["tools_used"]:
|
||||
result["tools_used"].append(name)
|
||||
result["tool_calls_made"] += len(tool_calls)
|
||||
|
||||
# 构建工具上下文
|
||||
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
|
||||
|
||||
# 更新提示词
|
||||
if round_num == max_rounds - 1:
|
||||
# 最后一轮,强制要求回答
|
||||
prompt = f"{original_prompt}\n\n{tool_context}\n\n⚠️ 重要:请基于以上工具查询结果,给出完整详细的最终答案。不要再调用工具。"
|
||||
tool_choice = "none"
|
||||
else:
|
||||
prompt = f"{original_prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,继续完成任务。"
|
||||
tool_choice = kwargs.get("tool_choice", "auto")
|
||||
|
||||
# 继续调用AI
|
||||
prov = self._get_provider(kwargs.get("provider"))
|
||||
next_response = await prov.generate(
|
||||
prompt=prompt,
|
||||
model=kwargs.get("model") or self.default_model,
|
||||
temperature=kwargs.get("temperature") or self.default_temperature,
|
||||
max_tokens=kwargs.get("max_tokens") or self.default_max_tokens,
|
||||
system_prompt=kwargs.get("system_prompt") or self.default_system_prompt,
|
||||
tools=None if tool_choice == "none" else self._cached_tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
|
||||
tool_calls = next_response.get("tool_calls", [])
|
||||
|
||||
if not tool_calls:
|
||||
# 没有更多工具调用,返回结果
|
||||
result["content"] = next_response.get("content", "")
|
||||
result["finish_reason"] = next_response.get("finish_reason", "stop")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 工具调用失败: {e}")
|
||||
result["content"] = response.get("content", "")
|
||||
result["finish_reason"] = "tool_error"
|
||||
break
|
||||
|
||||
return result
|
||||
|
||||
async def generate_text(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -89,10 +326,39 @@ class AIService:
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
auto_mcp: bool = True,
|
||||
handle_tool_calls: bool = True,
|
||||
mcp_max_rounds: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""生成文本"""
|
||||
"""
|
||||
生成文本(自动支持MCP工具)
|
||||
|
||||
Args:
|
||||
prompt: 用户提示词
|
||||
provider: AI提供商
|
||||
model: 模型名称
|
||||
temperature: 温度
|
||||
max_tokens: 最大令牌数
|
||||
system_prompt: 系统提示词
|
||||
tools: 手动指定的工具列表(优先级高于自动加载)
|
||||
tool_choice: 工具选择策略
|
||||
auto_mcp: 是否自动加载MCP工具(默认True)
|
||||
handle_tool_calls: 是否自动处理工具调用(默认True)
|
||||
mcp_max_rounds: 最大工具调用轮数(None使用默认值3)
|
||||
|
||||
Returns:
|
||||
包含生成内容的字典
|
||||
"""
|
||||
# 使用全局配置的MCP轮数(如果未指定)
|
||||
if mcp_max_rounds is None:
|
||||
mcp_max_rounds = app_settings.mcp_max_rounds
|
||||
|
||||
# 自动加载MCP工具
|
||||
if auto_mcp and tools is None:
|
||||
tools = await self._prepare_mcp_tools(auto_mcp=auto_mcp)
|
||||
|
||||
prov = self._get_provider(provider)
|
||||
return await prov.generate(
|
||||
response = await prov.generate(
|
||||
prompt=prompt,
|
||||
model=model or self.default_model,
|
||||
temperature=temperature or self.default_temperature,
|
||||
@@ -101,6 +367,22 @@ class AIService:
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
|
||||
# 处理工具调用
|
||||
if handle_tool_calls and response.get("tool_calls"):
|
||||
return await self._handle_tool_calls(
|
||||
original_prompt=prompt,
|
||||
response=response,
|
||||
provider=provider,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
tool_choice=tool_choice,
|
||||
max_rounds=mcp_max_rounds,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
async def generate_text_stream(
|
||||
self,
|
||||
@@ -110,15 +392,51 @@ class AIService:
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
auto_mcp: bool = True,
|
||||
mcp_max_rounds: Optional[int] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""流式生成"""
|
||||
"""
|
||||
流式生成文本(自动支持MCP工具)
|
||||
|
||||
工具调用在 Provider 层通过流式方式处理,支持真正的流式工具调用。
|
||||
|
||||
Args:
|
||||
prompt: 用户提示词
|
||||
provider: AI提供商
|
||||
model: 模型名称
|
||||
temperature: 温度
|
||||
max_tokens: 最大令牌数
|
||||
system_prompt: 系统提示词
|
||||
tool_choice: 工具选择策略("auto"/"none"/"required")
|
||||
auto_mcp: 是否自动加载MCP工具
|
||||
mcp_max_rounds: 最大工具调用轮数(None使用默认值3)
|
||||
|
||||
Yields:
|
||||
生成的文本块
|
||||
"""
|
||||
logger.debug(f"🔧 generate_text_stream: auto_mcp={auto_mcp}, tool_choice={tool_choice}")
|
||||
|
||||
tools_to_use = None
|
||||
|
||||
# 加载MCP工具
|
||||
if auto_mcp:
|
||||
tools_to_use = await self._prepare_mcp_tools(auto_mcp=auto_mcp)
|
||||
if tools_to_use:
|
||||
logger.info(f"🔧 已获取 {len(tools_to_use)} 个MCP工具")
|
||||
|
||||
# 流式生成(Provider 层处理工具调用)
|
||||
prov = self._get_provider(provider)
|
||||
logger.debug(f"🔧 开始流式生成,provider={provider or self.api_provider}, tools_count={len(tools_to_use) if tools_to_use else 0}")
|
||||
async for chunk in prov.generate_stream(
|
||||
prompt=prompt,
|
||||
model=model or self.default_model,
|
||||
temperature=temperature or self.default_temperature,
|
||||
max_tokens=max_tokens or self.default_max_tokens,
|
||||
system_prompt=system_prompt or self.default_system_prompt,
|
||||
tools=tools_to_use,
|
||||
tool_choice=tool_choice,
|
||||
user_id=self.user_id,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
@@ -132,8 +450,25 @@ class AIService:
|
||||
provider: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
expected_type: Optional[str] = None,
|
||||
auto_mcp: bool = True,
|
||||
) -> Union[Dict, List]:
|
||||
"""带重试的 JSON 调用"""
|
||||
"""
|
||||
带重试的 JSON 调用(自动支持MCP工具)
|
||||
|
||||
Args:
|
||||
prompt: 用户提示词
|
||||
system_prompt: 系统提示词
|
||||
max_retries: 最大重试次数
|
||||
temperature: 温度
|
||||
max_tokens: 最大令牌数
|
||||
provider: AI提供商
|
||||
model: 模型名称
|
||||
expected_type: 期望的返回类型("object"或"array")
|
||||
auto_mcp: 是否自动加载MCP工具
|
||||
|
||||
Returns:
|
||||
解析后的JSON数据
|
||||
"""
|
||||
last_response = ""
|
||||
|
||||
for attempt in range(1, max_retries + 1):
|
||||
@@ -146,6 +481,8 @@ class AIService:
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
auto_mcp=auto_mcp,
|
||||
handle_tool_calls=True,
|
||||
)
|
||||
|
||||
last_response = result.get("content", "")
|
||||
@@ -172,108 +509,6 @@ class AIService:
|
||||
"""清洗 JSON 响应"""
|
||||
return clean_json_response(text)
|
||||
|
||||
async def generate_text_with_mcp(
|
||||
self,
|
||||
prompt: str,
|
||||
user_id: str,
|
||||
db_session,
|
||||
enable_mcp: bool = True,
|
||||
max_tool_rounds: int = 3,
|
||||
tool_choice: str = "auto",
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""支持MCP工具的AI文本生成"""
|
||||
from app.services.mcp_tool_service import mcp_tool_service, MCPToolServiceError
|
||||
|
||||
result = {"content": "", "tool_calls_made": 0, "tools_used": [], "finish_reason": "", "mcp_enhanced": False}
|
||||
tools = None
|
||||
|
||||
if enable_mcp:
|
||||
try:
|
||||
tools = await mcp_tool_service.get_user_enabled_tools(user_id=user_id, db_session=db_session)
|
||||
if tools:
|
||||
result["mcp_enhanced"] = True
|
||||
except MCPToolServiceError:
|
||||
tools = None
|
||||
|
||||
original_prompt = prompt # 保存原始提示词
|
||||
|
||||
for round_num in range(max_tool_rounds):
|
||||
logger.debug(f"🔄 MCP工具调用 - 第{round_num+1}/{max_tool_rounds}轮")
|
||||
logger.debug(f" prompt长度: {len(prompt)}, tools数量: {len(tools) if tools else 0}, tool_choice: {tool_choice}")
|
||||
|
||||
ai_response = await self.generate_text(prompt=prompt, tools=tools, tool_choice=tool_choice, **kwargs)
|
||||
logger.debug(f" AI响应: finish_reason={ai_response.get('finish_reason')}, content长度={len(ai_response.get('content', ''))}")
|
||||
|
||||
tool_calls = ai_response.get("tool_calls", [])
|
||||
|
||||
if not tool_calls:
|
||||
content = ai_response.get("content", "")
|
||||
result["content"] = content
|
||||
result["finish_reason"] = ai_response.get("finish_reason", "stop")
|
||||
logger.debug(f" ✅ 无工具调用,返回内容长度: {len(content)}")
|
||||
|
||||
# 🔧 修复:如果内容为空且已经调用过工具,强制要求AI给出答案
|
||||
if not content.strip() and result["tool_calls_made"] > 0:
|
||||
logger.warning(f"⚠️ AI在工具调用后返回空内容,尝试强制要求回答(第{round_num+1}轮)")
|
||||
prompt = f"{prompt}\n\n⚠️ 请注意:你必须基于以上工具查询结果,给出完整的回答。不要返回空内容。"
|
||||
tools = None
|
||||
tool_choice = "none" # 强制不使用工具
|
||||
continue
|
||||
|
||||
break
|
||||
|
||||
logger.info(f"🔧 检测到 {len(tool_calls)} 个工具调用")
|
||||
for idx, tc in enumerate(tool_calls):
|
||||
logger.debug(f" 工具{idx+1}: {tc.get('function', {}).get('name')} - 参数: {tc.get('function', {}).get('arguments')}")
|
||||
|
||||
try:
|
||||
logger.debug(f" 开始执行工具调用...")
|
||||
tool_results = await mcp_tool_service.execute_tool_calls(user_id=user_id, tool_calls=tool_calls, db_session=db_session)
|
||||
logger.debug(f" 工具执行完成,结果数量: {len(tool_results)}")
|
||||
|
||||
# 🔍 检查工具结果
|
||||
for idx, tr in enumerate(tool_results):
|
||||
success = tr.get("success", False)
|
||||
content_preview = tr.get("content", "")[:200] if tr.get("content") else "None"
|
||||
logger.debug(f" 工具结果[{idx}]: success={success}, content预览={content_preview}")
|
||||
|
||||
for tc in tool_calls:
|
||||
name = tc["function"]["name"]
|
||||
if name not in result["tools_used"]:
|
||||
result["tools_used"].append(name)
|
||||
result["tool_calls_made"] += len(tool_calls)
|
||||
|
||||
tool_context = await mcp_tool_service.build_tool_context(tool_results, format="markdown")
|
||||
logger.debug(f" 工具上下文长度: {len(tool_context)}")
|
||||
logger.debug(f" 工具上下文预览: {tool_context[:300] if len(tool_context) > 300 else tool_context}")
|
||||
|
||||
# 🔧 改进:在最后一轮时,明确要求AI给出完整答案
|
||||
if round_num == max_tool_rounds - 1:
|
||||
logger.info(f"⚠️ 最后一轮,强制要求AI给出最终答案")
|
||||
prompt = f"{original_prompt}\n\n{tool_context}\n\n⚠️ 重要:这是最后一轮,请基于以上工具查询的参考资料,给出完整详细的最终答案。不要再调用工具。"
|
||||
tool_choice = "none"
|
||||
else:
|
||||
prompt = f"{original_prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,继续完成任务。"
|
||||
logger.debug(f" 新prompt长度: {len(prompt)}")
|
||||
|
||||
tools = None # 工具调用后禁用工具列表,避免重复调用
|
||||
logger.debug(f" ✅ 工具调用成功,准备下一轮")
|
||||
|
||||
except Exception as tool_error:
|
||||
logger.error(f"❌ 工具调用执行失败: {tool_error}", exc_info=True)
|
||||
logger.error(f" 错误类型: {type(tool_error).__name__}")
|
||||
logger.error(f" AI响应内容: {ai_response.get('content', '')[:200]}")
|
||||
result["content"] = ai_response.get("content", "")
|
||||
result["finish_reason"] = "tool_error"
|
||||
break
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# 全局实例
|
||||
ai_service = AIService()
|
||||
|
||||
|
||||
def create_user_ai_service(
|
||||
api_provider: str,
|
||||
@@ -284,7 +519,7 @@ def create_user_ai_service(
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
) -> AIService:
|
||||
"""创建用户 AI 服务"""
|
||||
"""创建用户 AI 服务(不带MCP支持)"""
|
||||
return AIService(
|
||||
api_provider=api_provider,
|
||||
api_key=api_key,
|
||||
@@ -293,4 +528,48 @@ def create_user_ai_service(
|
||||
default_temperature=temperature,
|
||||
default_max_tokens=max_tokens,
|
||||
default_system_prompt=system_prompt,
|
||||
)
|
||||
|
||||
|
||||
def create_user_ai_service_with_mcp(
|
||||
api_provider: str,
|
||||
api_key: str,
|
||||
api_base_url: str,
|
||||
model_name: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
user_id: str,
|
||||
db_session,
|
||||
system_prompt: Optional[str] = None,
|
||||
enable_mcp: bool = True,
|
||||
) -> AIService:
|
||||
"""
|
||||
创建支持MCP的用户AI服务
|
||||
|
||||
Args:
|
||||
api_provider: AI提供商
|
||||
api_key: API密钥
|
||||
api_base_url: API基础URL
|
||||
model_name: 模型名称
|
||||
temperature: 温度
|
||||
max_tokens: 最大令牌数
|
||||
user_id: 用户ID(用于加载MCP工具)
|
||||
db_session: 数据库会话
|
||||
system_prompt: 系统提示词
|
||||
enable_mcp: 是否启用MCP工具
|
||||
|
||||
Returns:
|
||||
配置好的AIService实例
|
||||
"""
|
||||
return AIService(
|
||||
api_provider=api_provider,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
default_model=model_name,
|
||||
default_temperature=temperature,
|
||||
default_max_tokens=max_tokens,
|
||||
default_system_prompt=system_prompt,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
enable_mcp=enable_mcp,
|
||||
)
|
||||
@@ -269,25 +269,11 @@ class AutoCharacterService:
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用AI分析(使用统一的JSON调用方法)
|
||||
if enable_mcp and user_id:
|
||||
result = await self.ai_service.generate_text_with_mcp(
|
||||
prompt=prompt,
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=2
|
||||
)
|
||||
content = result.get("content", "")
|
||||
# 使用统一的JSON清洗方法
|
||||
cleaned = self.ai_service._clean_json_response(content)
|
||||
analysis = json.loads(cleaned)
|
||||
else:
|
||||
# 非MCP调用:使用带自动重试的JSON调用
|
||||
analysis = await self.ai_service.call_with_json_retry(
|
||||
prompt=prompt,
|
||||
max_retries=3
|
||||
)
|
||||
# 使用统一的JSON调用方法(支持自动MCP工具加载)
|
||||
analysis = await self.ai_service.call_with_json_retry(
|
||||
prompt=prompt,
|
||||
max_retries=3,
|
||||
)
|
||||
|
||||
logger.info(f" ✅ AI分析完成: needs_new_characters={analysis.get('needs_new_characters')}")
|
||||
return analysis
|
||||
@@ -364,16 +350,16 @@ class AutoCharacterService:
|
||||
existing_characters=existing_chars_summary + careers_info,
|
||||
plot_context="根据剧情需要引入的新角色",
|
||||
character_specification=json.dumps(spec, ensure_ascii=False, indent=2),
|
||||
mcp_references="" # 暂时不使用MCP增强
|
||||
mcp_references="" # MCP工具通过AI服务自动加载
|
||||
)
|
||||
|
||||
# 调用AI生成(禁用MCP,避免累积超时导致卡死)
|
||||
logger.info(f"🔧 角色详情生成: enable_mcp={enable_mcp}")
|
||||
|
||||
# 调用AI生成
|
||||
try:
|
||||
# 🔧 优化:角色详情生成不使用MCP,只在分析阶段使用MCP
|
||||
# 这样可以减少大量的外部工具调用,避免超时和卡死
|
||||
character_data = await self.ai_service.call_with_json_retry(
|
||||
prompt=prompt,
|
||||
max_retries=2 # 减少重试次数以加快速度
|
||||
max_retries=2, # 减少重试次数以加快速度
|
||||
)
|
||||
|
||||
char_name = character_data.get('name', '未知')
|
||||
|
||||
@@ -292,25 +292,11 @@ class AutoOrganizationService:
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用AI分析(使用统一的JSON调用方法)
|
||||
if enable_mcp and user_id:
|
||||
result = await self.ai_service.generate_text_with_mcp(
|
||||
prompt=prompt,
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=2
|
||||
)
|
||||
content = result.get("content", "")
|
||||
# 使用统一的JSON清洗方法
|
||||
cleaned = self.ai_service._clean_json_response(content)
|
||||
analysis = json.loads(cleaned)
|
||||
else:
|
||||
# 非MCP调用:使用带自动重试的JSON调用
|
||||
analysis = await self.ai_service.call_with_json_retry(
|
||||
prompt=prompt,
|
||||
max_retries=3
|
||||
)
|
||||
# 使用统一的JSON调用方法(支持自动MCP工具加载)
|
||||
analysis = await self.ai_service.call_with_json_retry(
|
||||
prompt=prompt,
|
||||
max_retries=3,
|
||||
)
|
||||
|
||||
logger.info(f" ✅ AI分析完成: needs_new_organizations={analysis.get('needs_new_organizations')}")
|
||||
return analysis
|
||||
@@ -362,24 +348,11 @@ class AutoOrganizationService:
|
||||
|
||||
# 调用AI生成(使用统一的JSON调用方法)
|
||||
try:
|
||||
if enable_mcp and user_id:
|
||||
result = await self.ai_service.generate_text_with_mcp(
|
||||
prompt=prompt,
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=2
|
||||
)
|
||||
content = result.get("content", "")
|
||||
# 使用统一的JSON清洗方法
|
||||
cleaned = self.ai_service._clean_json_response(content)
|
||||
organization_data = json.loads(cleaned)
|
||||
else:
|
||||
# 非MCP调用:使用带自动重试的JSON调用
|
||||
organization_data = await self.ai_service.call_with_json_retry(
|
||||
prompt=prompt,
|
||||
max_retries=3
|
||||
)
|
||||
# 使用统一的JSON调用方法(支持自动MCP工具加载)
|
||||
organization_data = await self.ai_service.call_with_json_retry(
|
||||
prompt=prompt,
|
||||
max_retries=3,
|
||||
)
|
||||
|
||||
org_name = organization_data.get('name', '未知')
|
||||
logger.info(f" ✅ 组织详情生成成功: {org_name}")
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
"""MCP插件测试服务 - 专门处理插件测试逻辑"""
|
||||
"""MCP插件测试服务 - 专门处理插件测试逻辑
|
||||
|
||||
重构后使用统一的MCPClientFacade门面来管理所有MCP操作。
|
||||
"""
|
||||
|
||||
import time
|
||||
import json
|
||||
@@ -10,7 +13,7 @@ from sqlalchemy import select
|
||||
|
||||
from app.models.mcp_plugin import MCPPlugin
|
||||
from app.models.settings import Settings as UserSettings
|
||||
from app.mcp.registry import mcp_registry
|
||||
from app.mcp import mcp_client, MCPPluginConfig # 使用新的统一门面
|
||||
from app.services.ai_service import create_user_ai_service
|
||||
from app.schemas.mcp_plugin import MCPTestResult
|
||||
from app.services.prompt_service import prompt_service
|
||||
@@ -21,7 +24,32 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MCPTestService:
|
||||
"""MCP插件测试服务(分离的测试逻辑)"""
|
||||
"""MCP插件测试服务(使用统一门面重构)"""
|
||||
|
||||
async def _ensure_plugin_registered(
|
||||
self,
|
||||
plugin: MCPPlugin,
|
||||
user_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
确保插件已注册到统一门面
|
||||
|
||||
Args:
|
||||
plugin: 插件配置
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if plugin.plugin_type in ("http", "streamable_http", "sse") and plugin.server_url:
|
||||
return await mcp_client.ensure_registered(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin.plugin_name,
|
||||
url=plugin.server_url,
|
||||
plugin_type=plugin.plugin_type,
|
||||
headers=plugin.headers
|
||||
)
|
||||
return False
|
||||
|
||||
async def test_plugin_connection(
|
||||
self,
|
||||
@@ -41,19 +69,18 @@ class MCPTestService:
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 确保插件已加载
|
||||
if not mcp_registry.get_client(user_id, plugin.plugin_name):
|
||||
success = await mcp_registry.load_plugin(plugin)
|
||||
if not success:
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="插件加载失败",
|
||||
error="无法创建MCP客户端",
|
||||
suggestions=["请检查插件配置", "请确认服务器URL正确"]
|
||||
)
|
||||
# 确保插件已注册
|
||||
registered = await self._ensure_plugin_registered(plugin, user_id)
|
||||
if not registered:
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="插件注册失败",
|
||||
error="无法创建MCP客户端",
|
||||
suggestions=["请检查插件配置", "请确认服务器URL正确"]
|
||||
)
|
||||
|
||||
# 测试连接并获取工具列表
|
||||
test_result = await mcp_registry.test_plugin(user_id, plugin.plugin_name)
|
||||
# 使用统一门面测试连接
|
||||
test_result = await mcp_client.test_connection(user_id, plugin.plugin_name)
|
||||
|
||||
end_time = time.time()
|
||||
response_time = round((end_time - start_time) * 1000, 2)
|
||||
@@ -70,7 +97,18 @@ class MCPTestService:
|
||||
]
|
||||
)
|
||||
else:
|
||||
return MCPTestResult(**test_result)
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="❌ 连接测试失败",
|
||||
response_time_ms=response_time,
|
||||
error=test_result.get("message", "未知错误"),
|
||||
error_type=test_result.get("error_type"),
|
||||
suggestions=[
|
||||
"请检查服务器是否在线",
|
||||
"请确认配置正确",
|
||||
"请检查API Key是否有效"
|
||||
]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
@@ -117,8 +155,8 @@ class MCPTestService:
|
||||
if not connection_result.success:
|
||||
return connection_result
|
||||
|
||||
# 2. 获取工具列表
|
||||
tools = await mcp_registry.get_plugin_tools(user.user_id, plugin.plugin_name)
|
||||
# 2. 使用统一门面获取工具列表
|
||||
tools = await mcp_client.get_tools(user.user_id, plugin.plugin_name)
|
||||
|
||||
if not tools:
|
||||
return MCPTestResult(
|
||||
@@ -162,8 +200,8 @@ class MCPTestService:
|
||||
max_tokens=1000
|
||||
)
|
||||
|
||||
# 转换为OpenAI Function Calling格式
|
||||
openai_tools = self._convert_tools_to_openai_format(tools)
|
||||
# 使用统一门面转换为OpenAI Function Calling格式
|
||||
openai_tools = mcp_client.format_tools_for_openai(tools, plugin.plugin_name)
|
||||
|
||||
logger.info(f"📋 转换后的OpenAI工具数量: {len(openai_tools)}")
|
||||
logger.debug(f"📋 OpenAI工具列表: {[t['function']['name'] for t in openai_tools]}")
|
||||
@@ -175,26 +213,16 @@ class MCPTestService:
|
||||
db=db_session
|
||||
)
|
||||
|
||||
# 注意: generate_text_stream 返回的是异步生成器,但在 tool_choice="required" 模式下
|
||||
# AI服务会直接返回包含 tool_calls 的完整响应,而不是流式chunks
|
||||
# 因此这里需要特殊处理
|
||||
accumulated_text = ""
|
||||
tool_calls = None
|
||||
|
||||
async for chunk in ai_service.generate_text_stream(
|
||||
# 使用 generate_text 进行 Function Calling(非流式)
|
||||
ai_response = await ai_service.generate_text(
|
||||
prompt=prompts["user"],
|
||||
system_prompt=prompts["system"],
|
||||
tools=openai_tools,
|
||||
tool_choice="required"
|
||||
):
|
||||
# 在 function calling 模式下,chunk 可能是字典格式包含 tool_calls
|
||||
if isinstance(chunk, dict):
|
||||
if "tool_calls" in chunk:
|
||||
tool_calls = chunk["tool_calls"]
|
||||
if "content" in chunk:
|
||||
accumulated_text += chunk.get("content", "")
|
||||
else:
|
||||
accumulated_text += chunk
|
||||
tool_choice="auto"
|
||||
)
|
||||
|
||||
accumulated_text = ai_response.get("content", "")
|
||||
tool_calls = ai_response.get("tool_calls")
|
||||
|
||||
# 5. 检查AI是否返回工具调用
|
||||
if not tool_calls:
|
||||
@@ -214,7 +242,7 @@ class MCPTestService:
|
||||
# 6. 解析工具调用
|
||||
tool_call = tool_calls[0]
|
||||
function = tool_call["function"]
|
||||
tool_name = function["name"]
|
||||
tool_name_with_prefix = function["name"]
|
||||
test_arguments = function["arguments"]
|
||||
|
||||
if isinstance(test_arguments, str):
|
||||
@@ -231,17 +259,23 @@ class MCPTestService:
|
||||
tools_count=len(tools)
|
||||
)
|
||||
|
||||
# 解析插件名和工具名
|
||||
try:
|
||||
_, tool_name = mcp_client.parse_function_name(tool_name_with_prefix)
|
||||
except ValueError:
|
||||
tool_name = tool_name_with_prefix
|
||||
|
||||
logger.info(f"🤖 AI选择的工具: {tool_name}")
|
||||
logger.info(f"📝 AI生成的参数: {test_arguments}")
|
||||
|
||||
# 7. 调用MCP工具
|
||||
# 7. 使用统一门面调用MCP工具
|
||||
call_start = time.time()
|
||||
try:
|
||||
tool_result = await mcp_registry.call_tool(
|
||||
user.user_id,
|
||||
plugin.plugin_name,
|
||||
tool_name,
|
||||
test_arguments
|
||||
tool_result = await mcp_client.call_tool(
|
||||
user_id=user.user_id,
|
||||
plugin_name=plugin.plugin_name,
|
||||
tool_name=tool_name,
|
||||
arguments=test_arguments
|
||||
)
|
||||
|
||||
call_end = time.time()
|
||||
@@ -307,22 +341,6 @@ class MCPTestService:
|
||||
"请检查API Key是否有效"
|
||||
]
|
||||
)
|
||||
|
||||
def _convert_tools_to_openai_format(self, tools: list) -> list:
|
||||
"""将MCP工具格式转换为OpenAI Function Calling格式"""
|
||||
openai_tools = []
|
||||
for tool in tools:
|
||||
openai_tool = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool["name"],
|
||||
"description": tool.get("description", ""),
|
||||
}
|
||||
}
|
||||
if "inputSchema" in tool:
|
||||
openai_tool["function"]["parameters"] = tool["inputSchema"]
|
||||
openai_tools.append(openai_tool)
|
||||
return openai_tools
|
||||
|
||||
|
||||
# 全局单例
|
||||
|
||||
@@ -1,691 +0,0 @@
|
||||
"""MCP工具服务 - 统一管理MCP工具的注入和执行"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass, field
|
||||
from collections import defaultdict
|
||||
|
||||
from app.models.mcp_plugin import MCPPlugin
|
||||
from app.mcp.registry import mcp_registry
|
||||
from app.mcp.config import mcp_config
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolMetrics:
|
||||
"""工具调用指标"""
|
||||
total_calls: int = 0
|
||||
success_calls: int = 0
|
||||
failed_calls: int = 0
|
||||
total_duration_ms: float = 0.0
|
||||
avg_duration_ms: float = 0.0
|
||||
last_call_time: Optional[datetime] = None
|
||||
|
||||
def update_success(self, duration_ms: float):
|
||||
"""更新成功调用指标"""
|
||||
self.total_calls += 1
|
||||
self.success_calls += 1
|
||||
self.total_duration_ms += duration_ms
|
||||
self.avg_duration_ms = self.total_duration_ms / self.total_calls
|
||||
self.last_call_time = datetime.now()
|
||||
|
||||
def update_failure(self, duration_ms: float):
|
||||
"""更新失败调用指标"""
|
||||
self.total_calls += 1
|
||||
self.failed_calls += 1
|
||||
self.total_duration_ms += duration_ms
|
||||
self.avg_duration_ms = self.total_duration_ms / self.total_calls
|
||||
self.last_call_time = datetime.now()
|
||||
|
||||
@property
|
||||
def success_rate(self) -> float:
|
||||
"""成功率"""
|
||||
if self.total_calls == 0:
|
||||
return 0.0
|
||||
return self.success_calls / self.total_calls
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCacheEntry:
|
||||
"""工具缓存条目"""
|
||||
tools: List[Dict[str, Any]]
|
||||
expire_time: datetime
|
||||
hit_count: int = 0
|
||||
|
||||
|
||||
class MCPToolServiceError(Exception):
|
||||
"""MCP工具服务异常"""
|
||||
pass
|
||||
|
||||
|
||||
class MCPToolService:
|
||||
"""MCP工具服务 - 统一管理MCP工具的注入和执行(优化版)"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cache_ttl_minutes: Optional[int] = None,
|
||||
max_retries: Optional[int] = None
|
||||
):
|
||||
"""
|
||||
初始化MCP工具服务
|
||||
|
||||
Args:
|
||||
cache_ttl_minutes: 工具缓存TTL(分钟,默认使用配置)
|
||||
max_retries: 最大重试次数(默认使用配置)
|
||||
"""
|
||||
# 工具定义缓存: {cache_key: ToolCacheEntry}
|
||||
self._tool_cache: Dict[str, ToolCacheEntry] = {}
|
||||
self._cache_ttl = timedelta(
|
||||
minutes=cache_ttl_minutes or mcp_config.TOOL_CACHE_TTL_MINUTES
|
||||
)
|
||||
|
||||
# 调用指标: {tool_key: ToolMetrics}
|
||||
self._metrics: Dict[str, ToolMetrics] = defaultdict(ToolMetrics)
|
||||
|
||||
# 重试配置(使用配置常量)
|
||||
self._max_retries = max_retries or mcp_config.MAX_RETRIES
|
||||
self._base_retry_delay = mcp_config.BASE_RETRY_DELAY_SECONDS
|
||||
self._max_retry_delay = mcp_config.MAX_RETRY_DELAY_SECONDS
|
||||
|
||||
logger.info(
|
||||
f"✅ MCPToolService初始化完成 "
|
||||
f"(缓存TTL={self._cache_ttl.total_seconds()/60:.1f}分钟, "
|
||||
f"最大重试={self._max_retries}次)"
|
||||
)
|
||||
|
||||
async def get_user_enabled_tools(
|
||||
self,
|
||||
user_id: str,
|
||||
db_session: AsyncSession,
|
||||
category: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取用户启用的MCP工具列表
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
db_session: 数据库会话
|
||||
category: 工具类别筛选(search/analysis/filesystem等)
|
||||
|
||||
Returns:
|
||||
工具定义列表,格式符合OpenAI Function Calling规范
|
||||
"""
|
||||
try:
|
||||
# 1. 查询用户启用的插件(enabled=True即可,不强制要求status=active)
|
||||
# 因为新启用的插件status可能还是inactive,需要给它机会被调用
|
||||
query = select(MCPPlugin).where(
|
||||
MCPPlugin.user_id == user_id,
|
||||
MCPPlugin.enabled == True
|
||||
)
|
||||
|
||||
if category:
|
||||
query = query.where(MCPPlugin.category == category)
|
||||
|
||||
result = await db_session.execute(query)
|
||||
plugins = result.scalars().all()
|
||||
|
||||
if not plugins:
|
||||
logger.info(f"用户 {user_id} 没有启用的MCP插件")
|
||||
return []
|
||||
|
||||
# 2. 获取所有工具定义(使用缓存)
|
||||
all_tools = []
|
||||
for plugin in plugins:
|
||||
try:
|
||||
# 确保插件已加载到注册表
|
||||
if not mcp_registry.get_client(user_id, plugin.plugin_name):
|
||||
logger.info(f"插件 {plugin.plugin_name} 未加载,尝试加载...")
|
||||
success = await mcp_registry.load_plugin(plugin)
|
||||
if not success:
|
||||
logger.warning(f"插件 {plugin.plugin_name} 加载失败,跳过")
|
||||
continue
|
||||
|
||||
# ✅ 使用缓存获取工具列表
|
||||
plugin_tools = await self._get_plugin_tools_cached(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin.plugin_name
|
||||
)
|
||||
|
||||
# 格式化为Function Calling格式
|
||||
formatted_tools = self._format_tools_for_ai(
|
||||
plugin_tools,
|
||||
plugin.plugin_name
|
||||
)
|
||||
all_tools.extend(formatted_tools)
|
||||
|
||||
logger.info(
|
||||
f"从插件 {plugin.plugin_name} 加载了 "
|
||||
f"{len(formatted_tools)} 个工具"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"获取插件 {plugin.plugin_name} 的工具失败: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(f"用户 {user_id} 共加载 {len(all_tools)} 个MCP工具")
|
||||
return all_tools
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户MCP工具失败: {e}", exc_info=True)
|
||||
raise MCPToolServiceError(f"获取MCP工具失败: {str(e)}")
|
||||
|
||||
def _format_tools_for_ai(
|
||||
self,
|
||||
plugin_tools: List[Dict[str, Any]],
|
||||
plugin_name: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
将MCP工具定义格式化为AI Function Calling格式
|
||||
|
||||
Args:
|
||||
plugin_tools: MCP插件的工具列表
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
格式化后的工具列表
|
||||
"""
|
||||
formatted_tools = []
|
||||
|
||||
for tool in plugin_tools:
|
||||
formatted_tool = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": f"{plugin_name}_{tool['name']}", # 加插件前缀避免冲突
|
||||
"description": tool.get("description", ""),
|
||||
"parameters": tool.get("inputSchema", {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
})
|
||||
}
|
||||
}
|
||||
formatted_tools.append(formatted_tool)
|
||||
|
||||
return formatted_tools
|
||||
|
||||
async def _get_plugin_tools_cached(
|
||||
self,
|
||||
user_id: str,
|
||||
plugin_name: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
带缓存的工具列表获取
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
工具列表
|
||||
"""
|
||||
cache_key = f"{user_id}:{plugin_name}"
|
||||
now = datetime.now()
|
||||
|
||||
# 检查缓存
|
||||
if cache_key in self._tool_cache:
|
||||
entry = self._tool_cache[cache_key]
|
||||
if now < entry.expire_time:
|
||||
entry.hit_count += 1
|
||||
logger.debug(
|
||||
f"🎯 工具缓存命中: {cache_key} "
|
||||
f"(命中次数: {entry.hit_count})"
|
||||
)
|
||||
return entry.tools
|
||||
else:
|
||||
logger.debug(f"⏰ 工具缓存过期: {cache_key}")
|
||||
del self._tool_cache[cache_key]
|
||||
|
||||
# 缓存未命中,从MCP获取
|
||||
logger.debug(f"🔍 工具缓存未命中,从MCP获取: {cache_key}")
|
||||
tools = await mcp_registry.get_plugin_tools(user_id, plugin_name)
|
||||
|
||||
# 更新缓存
|
||||
self._tool_cache[cache_key] = ToolCacheEntry(
|
||||
tools=tools,
|
||||
expire_time=now + self._cache_ttl,
|
||||
hit_count=0
|
||||
)
|
||||
|
||||
return tools
|
||||
|
||||
def clear_cache(self, user_id: Optional[str] = None, plugin_name: Optional[str] = None):
|
||||
"""
|
||||
清理缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(可选,清理特定用户的缓存)
|
||||
plugin_name: 插件名称(可选,清理特定插件的缓存)
|
||||
"""
|
||||
if user_id is None and plugin_name is None:
|
||||
# 清理所有缓存
|
||||
self._tool_cache.clear()
|
||||
logger.info("🧹 已清理所有工具缓存")
|
||||
elif user_id and plugin_name:
|
||||
# 清理特定插件缓存
|
||||
cache_key = f"{user_id}:{plugin_name}"
|
||||
if cache_key in self._tool_cache:
|
||||
del self._tool_cache[cache_key]
|
||||
logger.info(f"🧹 已清理缓存: {cache_key}")
|
||||
elif user_id:
|
||||
# 清理用户所有缓存
|
||||
keys_to_delete = [
|
||||
key for key in self._tool_cache.keys()
|
||||
if key.startswith(f"{user_id}:")
|
||||
]
|
||||
for key in keys_to_delete:
|
||||
del self._tool_cache[key]
|
||||
logger.info(f"🧹 已清理用户缓存: {user_id} ({len(keys_to_delete)}个)")
|
||||
|
||||
async def execute_tool_calls(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_calls: List[Dict[str, Any]],
|
||||
db_session: AsyncSession,
|
||||
timeout: Optional[float] = None,
|
||||
max_concurrent: int = 2
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
批量执行AI请求的工具调用(限制并发数,避免超时)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
tool_calls: AI返回的工具调用列表
|
||||
db_session: 数据库会话
|
||||
timeout: 单个工具调用的超时时间(秒,默认使用配置)
|
||||
max_concurrent: 最大并发工具调用数(默认2)
|
||||
|
||||
Returns:
|
||||
工具调用结果列表
|
||||
"""
|
||||
if not tool_calls:
|
||||
return []
|
||||
|
||||
# 使用配置的默认超时
|
||||
actual_timeout = timeout or mcp_config.TOOL_CALL_TIMEOUT_SECONDS
|
||||
|
||||
logger.info(f"开始执行 {len(tool_calls)} 个工具调用 (超时={actual_timeout}s, 最大并发={max_concurrent})")
|
||||
|
||||
# ✅ 分批执行,每批最多max_concurrent个
|
||||
all_results = []
|
||||
for i in range(0, len(tool_calls), max_concurrent):
|
||||
batch = tool_calls[i:i+max_concurrent]
|
||||
batch_num = i // max_concurrent + 1
|
||||
total_batches = (len(tool_calls) + max_concurrent - 1) // max_concurrent
|
||||
|
||||
logger.info(f"执行工具批次 {batch_num}/{total_batches}, 数量: {len(batch)}")
|
||||
|
||||
# 创建当前批次的异步任务
|
||||
tasks = [
|
||||
self._execute_single_tool(
|
||||
user_id=user_id,
|
||||
tool_call=tool_call,
|
||||
db_session=db_session,
|
||||
timeout=actual_timeout
|
||||
)
|
||||
for tool_call in batch
|
||||
]
|
||||
|
||||
# 并行执行当前批次
|
||||
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 处理批次结果
|
||||
for j, result in enumerate(batch_results):
|
||||
tool_call = batch[j]
|
||||
|
||||
if isinstance(result, Exception):
|
||||
# 工具调用异常
|
||||
all_results.append({
|
||||
"tool_call_id": tool_call.get("id", f"call_{i+j}"),
|
||||
"role": "tool",
|
||||
"name": tool_call["function"]["name"],
|
||||
"content": f"工具调用失败: {str(result)}",
|
||||
"success": False,
|
||||
"error": str(result)
|
||||
})
|
||||
else:
|
||||
all_results.append(result)
|
||||
|
||||
# 批次间增加短暂延迟,避免API限流
|
||||
if i + max_concurrent < len(tool_calls):
|
||||
await asyncio.sleep(0.5)
|
||||
logger.debug(f"批次间延迟 0.5 秒...")
|
||||
|
||||
return all_results
|
||||
|
||||
async def _execute_single_tool(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_call: Dict[str, Any],
|
||||
db_session: AsyncSession,
|
||||
timeout: float
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
执行单个工具调用
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
tool_call: 工具调用信息
|
||||
db_session: 数据库会话
|
||||
timeout: 超时时间
|
||||
|
||||
Returns:
|
||||
工具调用结果
|
||||
"""
|
||||
tool_call_id = tool_call.get("id", "unknown")
|
||||
function_name = tool_call["function"]["name"]
|
||||
|
||||
try:
|
||||
# 解析插件名和工具名
|
||||
logger.debug(f"🔍 解析工具名称: {function_name}")
|
||||
if "_" in function_name:
|
||||
plugin_name, tool_name = function_name.split("_", 1)
|
||||
logger.debug(f" 插件: {plugin_name}, 工具: {tool_name}")
|
||||
else:
|
||||
raise ValueError(f"无效的工具名称格式: {function_name}")
|
||||
|
||||
# 解析参数
|
||||
arguments_str = tool_call["function"]["arguments"]
|
||||
logger.debug(f"🔍 解析参数:")
|
||||
logger.debug(f" 原始类型: {type(arguments_str)}")
|
||||
logger.debug(f" 原始内容: {arguments_str}")
|
||||
|
||||
if isinstance(arguments_str, str):
|
||||
try:
|
||||
arguments = json.loads(arguments_str)
|
||||
logger.debug(f" ✅ JSON解析成功: {arguments}")
|
||||
except json.JSONDecodeError as je:
|
||||
logger.error(f" ❌ JSON解析失败: {je}")
|
||||
logger.error(f" 原始字符串: '{arguments_str}'")
|
||||
raise ValueError(f"参数JSON解析失败: {je}")
|
||||
else:
|
||||
arguments = arguments_str
|
||||
logger.debug(f" 直接使用dict类型参数")
|
||||
|
||||
logger.info(
|
||||
f"执行工具: {plugin_name}.{tool_name}, "
|
||||
f"参数: {arguments}"
|
||||
)
|
||||
|
||||
# ✅ 使用带重试的调用
|
||||
tool_key = f"{plugin_name}.{tool_name}"
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
result = await self._call_tool_with_retry(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin_name,
|
||||
tool_name=tool_name,
|
||||
arguments=arguments,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
# 记录成功指标
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
self._metrics[tool_key].update_success(duration_ms)
|
||||
|
||||
logger.info(
|
||||
f"✅ 工具调用成功: {tool_key} "
|
||||
f"(耗时: {duration_ms:.2f}ms)"
|
||||
)
|
||||
|
||||
# 成功返回
|
||||
return {
|
||||
"tool_call_id": tool_call_id,
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"content": json.dumps(result, ensure_ascii=False),
|
||||
"success": True,
|
||||
"error": None
|
||||
}
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# 记录失败指标
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
self._metrics[tool_key].update_failure(duration_ms)
|
||||
raise MCPToolServiceError(
|
||||
f"工具调用超时(>{timeout}秒)"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# 记录失败指标
|
||||
tool_key = f"{plugin_name}.{tool_name}" if 'plugin_name' in locals() else function_name
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
self._metrics[tool_key].update_failure(duration_ms)
|
||||
|
||||
logger.error(
|
||||
f"❌ 工具 {function_name} 调用失败: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"tool_call_id": tool_call_id,
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"content": f"工具调用失败: {str(e)}",
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
async def _call_tool_with_retry(
|
||||
self,
|
||||
user_id: str,
|
||||
plugin_name: str,
|
||||
tool_name: str,
|
||||
arguments: Dict[str, Any],
|
||||
timeout: float
|
||||
) -> Any:
|
||||
"""
|
||||
带指数退避重试的工具调用
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
plugin_name: 插件名称
|
||||
tool_name: 工具名称
|
||||
arguments: 工具参数
|
||||
timeout: 超时时间
|
||||
|
||||
Returns:
|
||||
工具执行结果
|
||||
|
||||
Raises:
|
||||
MCPToolServiceError: 工具调用失败
|
||||
asyncio.TimeoutError: 调用超时
|
||||
"""
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(self._max_retries):
|
||||
try:
|
||||
# 尝试调用工具
|
||||
result = await asyncio.wait_for(
|
||||
mcp_registry.call_tool(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin_name,
|
||||
tool_name=tool_name,
|
||||
arguments=arguments
|
||||
),
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
# 成功则返回
|
||||
if attempt > 0:
|
||||
logger.info(
|
||||
f"✅ 重试成功: {plugin_name}.{tool_name} "
|
||||
f"(第{attempt + 1}次尝试)"
|
||||
)
|
||||
return result
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# 超时不重试,直接抛出
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
|
||||
# 最后一次尝试失败
|
||||
if attempt == self._max_retries - 1:
|
||||
logger.error(
|
||||
f"❌ 重试失败: {plugin_name}.{tool_name} "
|
||||
f"(已尝试{self._max_retries}次): {e}"
|
||||
)
|
||||
raise MCPToolServiceError(
|
||||
f"工具调用失败(已重试{self._max_retries}次): {str(e)}"
|
||||
)
|
||||
|
||||
# 计算指数退避延迟
|
||||
delay = min(
|
||||
self._base_retry_delay * (2 ** attempt),
|
||||
self._max_retry_delay
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
f"⚠️ 工具调用失败,{delay:.1f}秒后重试 "
|
||||
f"(第{attempt + 1}/{self._max_retries}次): "
|
||||
f"{plugin_name}.{tool_name} - {e}"
|
||||
)
|
||||
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# 理论上不会到这里,但为了类型安全
|
||||
raise MCPToolServiceError(f"工具调用失败: {last_exception}")
|
||||
|
||||
def get_metrics(self, tool_name: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
获取工具调用指标
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称(可选,获取特定工具的指标)
|
||||
|
||||
Returns:
|
||||
指标字典
|
||||
"""
|
||||
if tool_name:
|
||||
if tool_name in self._metrics:
|
||||
metric = self._metrics[tool_name]
|
||||
return {
|
||||
tool_name: {
|
||||
"total_calls": metric.total_calls,
|
||||
"success_calls": metric.success_calls,
|
||||
"failed_calls": metric.failed_calls,
|
||||
"success_rate": metric.success_rate,
|
||||
"avg_duration_ms": round(metric.avg_duration_ms, 2),
|
||||
"last_call_time": metric.last_call_time.isoformat() if metric.last_call_time else None
|
||||
}
|
||||
}
|
||||
return {}
|
||||
|
||||
# 返回所有工具的指标
|
||||
result = {}
|
||||
for tool_key, metric in self._metrics.items():
|
||||
result[tool_key] = {
|
||||
"total_calls": metric.total_calls,
|
||||
"success_calls": metric.success_calls,
|
||||
"failed_calls": metric.failed_calls,
|
||||
"success_rate": round(metric.success_rate, 3),
|
||||
"avg_duration_ms": round(metric.avg_duration_ms, 2),
|
||||
"last_call_time": metric.last_call_time.isoformat() if metric.last_call_time else None
|
||||
}
|
||||
return result
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
"""获取缓存统计信息"""
|
||||
total_entries = len(self._tool_cache)
|
||||
total_hits = sum(entry.hit_count for entry in self._tool_cache.values())
|
||||
|
||||
return {
|
||||
"total_entries": total_entries,
|
||||
"total_hits": total_hits,
|
||||
"cache_ttl_minutes": self._cache_ttl.total_seconds() / 60,
|
||||
"entries": [
|
||||
{
|
||||
"key": key,
|
||||
"tools_count": len(entry.tools),
|
||||
"hit_count": entry.hit_count,
|
||||
"expire_time": entry.expire_time.isoformat()
|
||||
}
|
||||
for key, entry in self._tool_cache.items()
|
||||
]
|
||||
}
|
||||
|
||||
async def build_tool_context(
|
||||
self,
|
||||
tool_results: List[Dict[str, Any]],
|
||||
format: str = "markdown"
|
||||
) -> str:
|
||||
"""
|
||||
将工具调用结果格式化为上下文文本
|
||||
|
||||
Args:
|
||||
tool_results: 工具调用结果列表
|
||||
format: 输出格式(markdown/json/plain)
|
||||
|
||||
Returns:
|
||||
格式化的上下文字符串
|
||||
"""
|
||||
if not tool_results:
|
||||
return ""
|
||||
|
||||
if format == "markdown":
|
||||
return self._build_markdown_context(tool_results)
|
||||
elif format == "json":
|
||||
return json.dumps(tool_results, ensure_ascii=False, indent=2)
|
||||
else: # plain
|
||||
return self._build_plain_context(tool_results)
|
||||
|
||||
def _build_markdown_context(
|
||||
self,
|
||||
tool_results: List[Dict[str, Any]]
|
||||
) -> str:
|
||||
"""构建Markdown格式的工具上下文"""
|
||||
lines = ["## 🔧 工具调用结果\n"]
|
||||
|
||||
for i, result in enumerate(tool_results, 1):
|
||||
tool_name = result.get("name", "unknown")
|
||||
success = result.get("success", False)
|
||||
content = result.get("content", "")
|
||||
|
||||
status_emoji = "✅" if success else "❌"
|
||||
lines.append(f"### {status_emoji} {i}. {tool_name}\n")
|
||||
|
||||
if success:
|
||||
# 尝试美化JSON内容
|
||||
try:
|
||||
content_obj = json.loads(content)
|
||||
content = json.dumps(content_obj, ensure_ascii=False, indent=2)
|
||||
except:
|
||||
pass
|
||||
lines.append(f"```json\n{content}\n```\n")
|
||||
else:
|
||||
lines.append(f"**错误**: {content}\n")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _build_plain_context(
|
||||
self,
|
||||
tool_results: List[Dict[str, Any]]
|
||||
) -> str:
|
||||
"""构建纯文本格式的工具上下文"""
|
||||
lines = ["=== 工具调用结果 ===\n"]
|
||||
|
||||
for i, result in enumerate(tool_results, 1):
|
||||
tool_name = result.get("name", "unknown")
|
||||
success = result.get("success", False)
|
||||
content = result.get("content", "")
|
||||
|
||||
status = "成功" if success else "失败"
|
||||
lines.append(f"{i}. {tool_name} - {status}")
|
||||
lines.append(f" 结果: {content}\n")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# 全局单例
|
||||
mcp_tool_service = MCPToolService()
|
||||
@@ -0,0 +1,235 @@
|
||||
"""MCP工具加载器 - 统一的工具获取入口
|
||||
|
||||
在AI请求之前,自动检查用户MCP配置并加载可用工具。
|
||||
"""
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.logger import get_logger
|
||||
from app.models.mcp_plugin import MCPPlugin
|
||||
from app.mcp import mcp_client
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserToolsCache:
|
||||
"""用户工具缓存条目"""
|
||||
tools: Optional[List[Dict[str, Any]]]
|
||||
expire_time: datetime
|
||||
hit_count: int = 0
|
||||
|
||||
|
||||
class MCPToolsLoader:
|
||||
"""
|
||||
MCP工具加载器
|
||||
|
||||
负责:
|
||||
1. 检查用户是否配置并启用了MCP插件
|
||||
2. 从各个启用的插件加载工具列表
|
||||
3. 将工具转换为OpenAI Function Calling格式
|
||||
4. 缓存结果以提升性能
|
||||
"""
|
||||
|
||||
_instance: Optional['MCPToolsLoader'] = None
|
||||
|
||||
def __new__(cls):
|
||||
"""单例模式"""
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
# 用户工具缓存: user_id -> UserToolsCache
|
||||
self._cache: Dict[str, UserToolsCache] = {}
|
||||
|
||||
# 缓存TTL(5分钟)
|
||||
self._cache_ttl = timedelta(minutes=5)
|
||||
|
||||
self._initialized = True
|
||||
logger.info("✅ MCPToolsLoader 初始化完成")
|
||||
|
||||
async def has_enabled_plugins(
|
||||
self,
|
||||
user_id: str,
|
||||
db_session: AsyncSession
|
||||
) -> bool:
|
||||
"""
|
||||
检查用户是否有启用的MCP插件
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
db_session: 数据库会话
|
||||
|
||||
Returns:
|
||||
是否有启用的插件
|
||||
"""
|
||||
try:
|
||||
query = select(MCPPlugin.id).where(
|
||||
MCPPlugin.user_id == user_id,
|
||||
MCPPlugin.enabled == True,
|
||||
MCPPlugin.plugin_type.in_(["http", "streamable_http", "sse"])
|
||||
).limit(1)
|
||||
|
||||
result = await db_session.execute(query)
|
||||
return result.scalar() is not None
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"检查用户MCP插件失败: {e}")
|
||||
return False
|
||||
|
||||
async def get_user_tools(
|
||||
self,
|
||||
user_id: str,
|
||||
db_session: AsyncSession,
|
||||
use_cache: bool = True,
|
||||
force_refresh: bool = False
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""
|
||||
获取用户的MCP工具列表(OpenAI格式)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
db_session: 数据库会话
|
||||
use_cache: 是否使用缓存
|
||||
force_refresh: 是否强制刷新
|
||||
|
||||
Returns:
|
||||
- None: 用户未配置或未启用任何MCP插件
|
||||
- []: 有配置但没有可用工具
|
||||
- List[Dict]: OpenAI Function Calling格式的工具列表
|
||||
"""
|
||||
now = datetime.now()
|
||||
|
||||
# 检查缓存
|
||||
if use_cache and not force_refresh and user_id in self._cache:
|
||||
cache_entry = self._cache[user_id]
|
||||
if now < cache_entry.expire_time:
|
||||
cache_entry.hit_count += 1
|
||||
logger.debug(f"🎯 用户工具缓存命中: {user_id} (命中次数: {cache_entry.hit_count})")
|
||||
return cache_entry.tools
|
||||
else:
|
||||
del self._cache[user_id]
|
||||
logger.debug(f"⏰ 用户工具缓存过期: {user_id}")
|
||||
|
||||
# 从数据库加载
|
||||
try:
|
||||
tools = await self._load_user_tools(user_id, db_session)
|
||||
|
||||
# 更新缓存
|
||||
self._cache[user_id] = UserToolsCache(
|
||||
tools=tools,
|
||||
expire_time=now + self._cache_ttl
|
||||
)
|
||||
|
||||
if tools:
|
||||
logger.info(f"🔧 用户 {user_id} 加载了 {len(tools)} 个MCP工具")
|
||||
else:
|
||||
logger.debug(f"📭 用户 {user_id} 没有可用的MCP工具")
|
||||
|
||||
return tools
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 加载用户MCP工具失败: {e}")
|
||||
return None
|
||||
|
||||
async def _load_user_tools(
|
||||
self,
|
||||
user_id: str,
|
||||
db_session: AsyncSession
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""
|
||||
从数据库加载用户启用的MCP插件并获取工具
|
||||
"""
|
||||
# 查询启用的插件
|
||||
query = select(MCPPlugin).where(
|
||||
MCPPlugin.user_id == user_id,
|
||||
MCPPlugin.enabled == True,
|
||||
MCPPlugin.plugin_type.in_(["http", "streamable_http", "sse"])
|
||||
).order_by(MCPPlugin.sort_order)
|
||||
|
||||
result = await db_session.execute(query)
|
||||
plugins = result.scalars().all()
|
||||
|
||||
if not plugins:
|
||||
return None
|
||||
|
||||
all_tools = []
|
||||
|
||||
for plugin in plugins:
|
||||
try:
|
||||
# 确定插件类型
|
||||
plugin_type = plugin.plugin_type
|
||||
if plugin_type == "http":
|
||||
plugin_type = "streamable_http" # 默认使用streamable_http
|
||||
|
||||
# 确保插件已注册到MCP客户端
|
||||
await mcp_client.ensure_registered(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin.plugin_name,
|
||||
url=plugin.server_url,
|
||||
plugin_type=plugin_type,
|
||||
headers=plugin.headers
|
||||
)
|
||||
|
||||
# 获取工具列表
|
||||
plugin_tools = await mcp_client.get_tools(user_id, plugin.plugin_name)
|
||||
|
||||
# 转换为OpenAI格式
|
||||
formatted = mcp_client.format_tools_for_openai(plugin_tools, plugin.plugin_name)
|
||||
all_tools.extend(formatted)
|
||||
|
||||
logger.debug(f"✅ 从插件 {plugin.plugin_name} 加载了 {len(formatted)} 个工具")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 加载插件 {plugin.plugin_name} 工具失败: {e}")
|
||||
continue
|
||||
|
||||
return all_tools if all_tools else None
|
||||
|
||||
def invalidate_cache(self, user_id: Optional[str] = None):
|
||||
"""
|
||||
使缓存失效
|
||||
|
||||
Args:
|
||||
user_id: 用户ID,为None时清空所有缓存
|
||||
"""
|
||||
if user_id:
|
||||
if user_id in self._cache:
|
||||
del self._cache[user_id]
|
||||
logger.debug(f"🧹 清理用户工具缓存: {user_id}")
|
||||
else:
|
||||
count = len(self._cache)
|
||||
self._cache.clear()
|
||||
logger.info(f"🧹 清理所有用户工具缓存 ({count}个)")
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
"""获取缓存统计"""
|
||||
now = datetime.now()
|
||||
return {
|
||||
"total_entries": len(self._cache),
|
||||
"total_hits": sum(e.hit_count for e in self._cache.values()),
|
||||
"cache_ttl_minutes": self._cache_ttl.total_seconds() / 60,
|
||||
"entries": [
|
||||
{
|
||||
"user_id": uid,
|
||||
"tools_count": len(e.tools) if e.tools else 0,
|
||||
"hit_count": e.hit_count,
|
||||
"expired": now >= e.expire_time,
|
||||
"expire_time": e.expire_time.isoformat()
|
||||
}
|
||||
for uid, e in self._cache.items()
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
# 全局单例
|
||||
mcp_tools_loader = MCPToolsLoader()
|
||||
Reference in New Issue
Block a user