feat: 重构MCP功能和AI服务提供者架构
This commit is contained in:
+253
-7
@@ -22,7 +22,7 @@ from app.schemas.settings import (
|
||||
from app.user_manager import User
|
||||
from app.logger import get_logger
|
||||
from app.config import settings as app_settings, PROJECT_ROOT
|
||||
from app.services.ai_service import AIService, create_user_ai_service
|
||||
from app.services.ai_service import AIService, create_user_ai_service, create_user_ai_service_with_mcp
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -53,9 +53,14 @@ async def get_user_ai_service(
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> AIService:
|
||||
"""
|
||||
依赖:获取当前用户的AI服务实例
|
||||
从数据库读取用户设置并创建对应的AI服务
|
||||
依赖:获取当前用户的AI服务实例(支持MCP工具自动加载)
|
||||
|
||||
从数据库读取用户设置并创建对应的AI服务。
|
||||
自动传递 user_id 和 db_session,使得 AIService 能够加载用户配置的MCP工具。
|
||||
根据用户的所有MCP插件状态决定是否启用MCP:如果有启用的插件则启用,否则禁用。
|
||||
"""
|
||||
from app.models.mcp_plugin import MCPPlugin
|
||||
|
||||
result = await db.execute(
|
||||
select(Settings).where(Settings.user_id == user.user_id)
|
||||
)
|
||||
@@ -73,15 +78,34 @@ async def get_user_ai_service(
|
||||
await db.refresh(settings)
|
||||
logger.info(f"用户 {user.user_id} 首次使用AI服务,已从.env同步设置到数据库")
|
||||
|
||||
# 使用用户设置创建AI服务实例(包括系统提示词)
|
||||
return create_user_ai_service(
|
||||
# 查询用户的所有MCP插件状态
|
||||
mcp_result = await db.execute(
|
||||
select(MCPPlugin).where(MCPPlugin.user_id == user.user_id)
|
||||
)
|
||||
mcp_plugins = mcp_result.scalars().all()
|
||||
|
||||
# 检查是否有启用的MCP插件
|
||||
enable_mcp = any(plugin.enabled for plugin in mcp_plugins) if mcp_plugins else False
|
||||
|
||||
if mcp_plugins:
|
||||
enabled_count = sum(1 for p in mcp_plugins if p.enabled)
|
||||
logger.info(f"用户 {user.user_id} 有 {len(mcp_plugins)} 个MCP插件,{enabled_count} 个启用,{enable_mcp} 决定使用MCP")
|
||||
else:
|
||||
logger.debug(f"用户 {user.user_id} 没有配置MCP插件,禁用MCP")
|
||||
|
||||
# ✅ 使用支持MCP的工厂函数创建AI服务实例
|
||||
# 传递 user_id 和 db_session,使得 AIService 能够自动加载用户配置的MCP工具
|
||||
return create_user_ai_service_with_mcp(
|
||||
api_provider=settings.api_provider,
|
||||
api_key=settings.api_key,
|
||||
api_base_url=settings.api_base_url or "",
|
||||
model_name=settings.llm_model,
|
||||
temperature=settings.temperature,
|
||||
max_tokens=settings.max_tokens,
|
||||
system_prompt=settings.system_prompt # 传递系统提示词
|
||||
user_id=user.user_id, # ✅ 传递 user_id
|
||||
db_session=db, # ✅ 传递 db_session
|
||||
system_prompt=settings.system_prompt,
|
||||
enable_mcp=enable_mcp, # 根据MCP插件状态动态决定
|
||||
)
|
||||
|
||||
|
||||
@@ -327,6 +351,227 @@ class ApiTestRequest(BaseModel):
|
||||
llm_model: str
|
||||
|
||||
|
||||
@router.post("/check-function-calling")
|
||||
async def check_function_calling_support(data: ApiTestRequest):
|
||||
"""
|
||||
检查模型是否支持 Function Calling(工具调用)
|
||||
|
||||
基于业界最佳实践的测试方法:
|
||||
1. 发送包含工具定义的请求
|
||||
2. 检查响应的 finish_reason 是否为 "tool_calls"
|
||||
3. 验证响应中是否包含有效的 tool_calls 数据
|
||||
|
||||
Args:
|
||||
data: 包含 API 配置的请求数据
|
||||
|
||||
Returns:
|
||||
检测结果包含支持状态、详细信息和建议
|
||||
"""
|
||||
api_key = data.api_key
|
||||
api_base_url = data.api_base_url
|
||||
provider = data.provider
|
||||
llm_model = data.llm_model
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# 定义一个简单的测试工具(天气查询)
|
||||
test_tools = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "获取指定城市的当前天气信息",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "城市名称,例如:北京、上海、深圳"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "温度单位"
|
||||
}
|
||||
},
|
||||
"required": ["city"]
|
||||
}
|
||||
}
|
||||
}]
|
||||
|
||||
# 测试提示:故意设计一个需要调用工具的问题
|
||||
test_prompt = "请告诉我北京现在的天气情况如何?"
|
||||
|
||||
logger.info(f"🧪 开始检测 Function Calling 支持")
|
||||
logger.info(f" - 提供商: {provider}")
|
||||
logger.info(f" - 模型: {llm_model}")
|
||||
logger.info(f" - 测试工具: get_weather")
|
||||
|
||||
# 创建临时 AI 服务实例进行测试
|
||||
test_service = AIService(
|
||||
api_provider=provider,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
default_model=llm_model,
|
||||
default_temperature=0.3, # 使用较低温度以获得更确定的行为
|
||||
default_max_tokens=200
|
||||
)
|
||||
|
||||
# 发送带工具的测试请求
|
||||
response = await test_service.generate_text(
|
||||
prompt=test_prompt,
|
||||
provider=provider,
|
||||
model=llm_model,
|
||||
temperature=0.3,
|
||||
max_tokens=200,
|
||||
tools=test_tools,
|
||||
tool_choice="auto", # 让模型自动决定是否使用工具
|
||||
auto_mcp=False # 禁用 MCP 自动加载
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
response_time = round((end_time - start_time) * 1000, 2)
|
||||
|
||||
# 分析响应以确定是否支持 Function Calling
|
||||
supported = False
|
||||
finish_reason = None
|
||||
tool_calls = None
|
||||
response_content = None
|
||||
|
||||
if isinstance(response, dict):
|
||||
# 检查 finish_reason(OpenAI 标准)
|
||||
finish_reason = response.get("finish_reason")
|
||||
|
||||
# 检查是否有 tool_calls
|
||||
if "tool_calls" in response and response["tool_calls"]:
|
||||
supported = True
|
||||
tool_calls = response["tool_calls"]
|
||||
logger.info(f"✅ 检测到工具调用: {len(tool_calls)} 个")
|
||||
|
||||
# 记录返回的内容(如果有)
|
||||
if "content" in response:
|
||||
response_content = response["content"]
|
||||
elif isinstance(response, str):
|
||||
# 如果只返回字符串,说明不支持工具调用
|
||||
response_content = response
|
||||
|
||||
logger.info(f" - 响应时间: {response_time}ms")
|
||||
logger.info(f" - finish_reason: {finish_reason}")
|
||||
logger.info(f" - 支持状态: {'✅ 支持' if supported else '❌ 不支持'}")
|
||||
|
||||
# 构建详细的返回信息
|
||||
result = {
|
||||
"success": True,
|
||||
"supported": supported,
|
||||
"message": "✅ 模型支持 Function Calling" if supported else "❌ 模型不支持 Function Calling",
|
||||
"response_time_ms": response_time,
|
||||
"provider": provider,
|
||||
"model": llm_model,
|
||||
"details": {
|
||||
"finish_reason": finish_reason,
|
||||
"has_tool_calls": bool(tool_calls),
|
||||
"tool_call_count": len(tool_calls) if tool_calls else 0,
|
||||
"test_tool": "get_weather",
|
||||
"test_prompt": test_prompt,
|
||||
"response_type": "tool_calls" if supported else "text"
|
||||
}
|
||||
}
|
||||
|
||||
# 添加工具调用详情
|
||||
if tool_calls:
|
||||
result["tool_calls"] = tool_calls
|
||||
result["suggestions"] = [
|
||||
"✅ 该模型支持 Function Calling,可以正常使用 MCP 插件",
|
||||
"建议:启用需要的 MCP 插件以扩展 AI 能力",
|
||||
"提示:测试成功检测到工具调用,模型能够正确解析和使用外部工具"
|
||||
]
|
||||
else:
|
||||
result["response_preview"] = response_content[:200] if response_content else None
|
||||
result["suggestions"] = [
|
||||
"❌ 该模型不支持 Function Calling,无法使用 MCP 插件功能",
|
||||
"建议:更换支持工具调用的模型",
|
||||
"推荐模型:GPT-4 系列、GPT-4-turbo、Claude 3 Opus/Sonnet、Gemini 1.5 Pro 等",
|
||||
"说明:模型返回了文本回复而非工具调用,表明不支持该功能"
|
||||
]
|
||||
|
||||
return result
|
||||
|
||||
except ValueError as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f"❌ Function Calling 检测配置错误: {error_msg}")
|
||||
return {
|
||||
"success": False,
|
||||
"supported": False,
|
||||
"message": "配置错误",
|
||||
"error": error_msg,
|
||||
"error_type": "ConfigurationError",
|
||||
"suggestions": [
|
||||
"请检查 API Key 是否正确",
|
||||
"请确认 API Base URL 格式是否正确",
|
||||
"请验证所选提供商与配置是否匹配"
|
||||
]
|
||||
}
|
||||
|
||||
except TimeoutError as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f"❌ Function Calling 检测超时: {error_msg}")
|
||||
return {
|
||||
"success": False,
|
||||
"supported": None,
|
||||
"message": "检测超时",
|
||||
"error": error_msg,
|
||||
"error_type": "TimeoutError",
|
||||
"suggestions": [
|
||||
"请检查网络连接是否正常",
|
||||
"请确认 API 服务是否可访问",
|
||||
"建议:稍后重试或使用其他网络环境"
|
||||
]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
error_type = type(e).__name__
|
||||
|
||||
logger.error(f"❌ Function Calling 检测失败: {error_msg}")
|
||||
logger.error(f" - 错误类型: {error_type}")
|
||||
|
||||
# 智能分析错误原因
|
||||
suggestions = []
|
||||
if "tool" in error_msg.lower() or "function" in error_msg.lower():
|
||||
suggestions = [
|
||||
"该模型可能不支持 Function Calling 功能",
|
||||
"API 返回了与工具调用相关的错误",
|
||||
"建议:更换支持工具调用的模型或联系 API 提供商"
|
||||
]
|
||||
elif "unauthorized" in error_msg.lower() or "401" in error_msg:
|
||||
suggestions = [
|
||||
"API Key 认证失败",
|
||||
"请检查 API Key 是否正确且有效",
|
||||
"请确认 API Key 是否有足够的权限"
|
||||
]
|
||||
elif "not found" in error_msg.lower() or "404" in error_msg:
|
||||
suggestions = [
|
||||
"模型不存在或不可用",
|
||||
"请检查模型名称是否正确",
|
||||
"请确认该模型在当前 API 中是否可用"
|
||||
]
|
||||
else:
|
||||
suggestions = [
|
||||
"检测过程中遇到未知错误",
|
||||
"建议:检查所有配置参数是否正确",
|
||||
"提示:查看详细错误信息以获取更多线索"
|
||||
]
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"supported": False,
|
||||
"message": "Function Calling 检测失败",
|
||||
"error": error_msg,
|
||||
"error_type": error_type,
|
||||
"suggestions": suggestions
|
||||
}
|
||||
|
||||
|
||||
@router.post("/test")
|
||||
async def test_api_connection(data: ApiTestRequest):
|
||||
"""
|
||||
@@ -370,7 +615,8 @@ async def test_api_connection(data: ApiTestRequest):
|
||||
provider=provider,
|
||||
model=llm_model,
|
||||
temperature=0.7,
|
||||
max_tokens=8000
|
||||
max_tokens=8000,
|
||||
auto_mcp=False # 测试时不加载MCP工具
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
Reference in New Issue
Block a user