736 lines
28 KiB
Python
736 lines
28 KiB
Python
"""AI服务封装 - 统一的AI接口
|
||
|
||
重构后支持自动MCP工具加载:
|
||
- 所有AI方法在请求前自动检查用户MCP配置
|
||
- 如果有启用的MCP插件且有可用工具,自动发送tools
|
||
- 通过 auto_mcp 参数控制是否启用自动工具加载
|
||
"""
|
||
from typing import Optional, AsyncGenerator, List, Dict, Any, Union
|
||
|
||
from app.config import settings as app_settings
|
||
from app.logger import get_logger
|
||
from app.services.ai_config import AIClientConfig, default_config
|
||
from app.services.ai_metrics import AICallMetrics, TokenUsage, ToolCallMetrics
|
||
from app.services.ai_clients.openai_client import OpenAIClient
|
||
from app.services.ai_clients.anthropic_client import AnthropicClient
|
||
from app.services.ai_clients.gemini_client import GeminiClient
|
||
from app.services.ai_clients.base_client import cleanup_all_clients
|
||
from app.services.ai_providers.openai_provider import OpenAIProvider
|
||
from app.services.ai_providers.anthropic_provider import AnthropicProvider
|
||
from app.services.ai_providers.gemini_provider import GeminiProvider
|
||
from app.services.ai_providers.base_provider import BaseAIProvider
|
||
from app.services.json_helper import clean_json_response, parse_json
|
||
|
||
logger = get_logger(__name__)

# Re-export the shared client cleanup hook under the name this module's
# callers import.
cleanup_http_clients = cleanup_all_clients
|
||
|
||
|
||
def normalize_provider(provider: Optional[str]) -> Optional[str]:
    """Map channel aliases onto a canonical provider name.

    The ``"mumu"`` channel is treated as ``"openai"``; any other value,
    including ``None``, is returned unchanged.
    """
    return "openai" if provider == "mumu" else provider
|
||
|
||
|
||
class AIService:
    """
    Unified AI service interface.

    MCP tool support:
    - Pass ``user_id`` and ``db_session`` when creating the service.
    - Whether MCP is enabled is decided from the enabled state of the user's
      MCP plugins.
    - If any MCP plugin is enabled, tools are loaded and used.
    - If all plugins are disabled, no MCP tools are used.
    - ``auto_mcp=False`` temporarily disables automatic tool loading for a call.
    - ``mcp_max_rounds`` controls the number of tool-call rounds.
    - ``clear_mcp_cache()`` clears the MCP tool cache.

    MCP enablement logic (``get_user_ai_service`` in backend/app/api/settings.py):
    - Query all of the user's MCP plugins.
    - If any plugin is enabled (``enabled=True``), ``enable_mcp=True``.
    - If all plugins are disabled or there are none, ``enable_mcp=False``.

    Example:
        # Create an MCP-capable service (MCP on/off follows plugin state)
        ai_service = create_user_ai_service_with_mcp(
            api_provider="openai",
            api_key="...",
            # ... plus api_base_url, model_name, temperature, max_tokens
            user_id="user123",
            db_session=db,
        )

        # MCP tools are loaded automatically (if any plugin is enabled)
        result = await ai_service.generate_text(prompt="...")

        # Temporarily disable MCP tools
        result = await ai_service.generate_text(prompt="...", auto_mcp=False)

        # Custom number of tool-call rounds
        result = await ai_service.generate_text(prompt="...", mcp_max_rounds=3)
    """

    def __init__(
        self,
        api_provider: Optional[str] = None,
        api_key: Optional[str] = None,
        api_base_url: Optional[str] = None,
        default_model: Optional[str] = None,
        default_temperature: Optional[float] = None,
        default_max_tokens: Optional[int] = None,
        default_system_prompt: Optional[str] = None,
        config: Optional[AIClientConfig] = None,
        # MCP support
        user_id: Optional[str] = None,
        db_session: Optional[Any] = None,
        enable_mcp: bool = True,
    ):
        """Create the service and initialize every provider that has credentials.

        ``api_key`` and ``api_base_url`` are used only for the selected
        ``api_provider``. The non-selected OpenAI/Anthropic providers fall back
        to the keys and base URLs in ``app_settings``. Gemini is initialized only
        when it is the selected provider and ``api_key`` is given.
        """
        self.api_provider = normalize_provider(api_provider or app_settings.default_ai_provider)
        self.default_model = default_model or app_settings.default_model
        # Compare against None so an explicit 0.0 (deterministic sampling) is
        # kept instead of silently being replaced by the global default.
        self.default_temperature = (
            app_settings.default_temperature if default_temperature is None else default_temperature
        )
        self.default_max_tokens = default_max_tokens or app_settings.default_max_tokens
        self.default_system_prompt = default_system_prompt
        self.config = config or default_config

        # MCP state
        self.user_id = user_id
        self.db_session = db_session
        self._enable_mcp = enable_mcp
        self._cached_tools: Optional[List[Dict]] = None
        self._tools_loaded = False

        self._openai_provider: Optional[OpenAIProvider] = None
        self._anthropic_provider: Optional[AnthropicProvider] = None
        self._gemini_provider: Optional[GeminiProvider] = None

        # OpenAI
        openai_key = api_key if self.api_provider == "openai" else app_settings.openai_api_key
        if openai_key:
            base_url = api_base_url if self.api_provider == "openai" else app_settings.openai_base_url
            client = OpenAIClient(openai_key, base_url or "https://api.openai.com/v1", self.config)
            self._openai_provider = OpenAIProvider(client)

        # Anthropic
        anthropic_key = api_key if self.api_provider == "anthropic" else app_settings.anthropic_api_key
        if anthropic_key:
            base_url = api_base_url if self.api_provider == "anthropic" else app_settings.anthropic_base_url
            client = AnthropicClient(anthropic_key, base_url, self.config)
            self._anthropic_provider = AnthropicProvider(client)

        # Gemini (only with an explicit key for the selected provider)
        if self.api_provider == "gemini" and api_key:
            client = GeminiClient(api_key, api_base_url, self.config)
            self._gemini_provider = GeminiProvider(client)

    @property
    def enable_mcp(self) -> bool:
        """Whether MCP tools may be used by this service."""
        return self._enable_mcp

    @enable_mcp.setter
    def enable_mcp(self, value: bool):
        """Set the MCP flag; turning it off also clears the tool cache."""
        # Truthiness rather than identity so falsy non-bool values (0, None)
        # also clear the cache when transitioning from enabled to disabled.
        if not value and self._enable_mcp:
            self.clear_mcp_cache()
        self._enable_mcp = value

    def clear_mcp_cache(self):
        """
        Clear the cached MCP tools.

        Called when MCP is disabled so later AI calls cannot use stale tools.
        Also resets ``_tools_loaded`` so the next call checks again.
        """
        if self._cached_tools is not None:
            logger.info(f"🔧 清理MCP工具缓存,移除 {len(self._cached_tools)} 个工具")
            self._cached_tools = None
        else:
            logger.debug("🔧 MCP工具缓存已经是空,无需清理")

        # Reset the load state so the next call re-checks.
        self._tools_loaded = False
        logger.debug(f"🔧 MCP工具状态已重置: enable_mcp={self._enable_mcp}, _tools_loaded=False")

    def _get_provider(self, provider: Optional[str] = None) -> BaseAIProvider:
        """Return the initialized provider for ``provider`` (default: the service's).

        Raises:
            ValueError: if that provider was not initialized in ``__init__``.
        """
        p = normalize_provider(provider or self.api_provider)
        if p == "openai" and self._openai_provider:
            return self._openai_provider
        if p == "anthropic" and self._anthropic_provider:
            return self._anthropic_provider
        if p == "gemini" and self._gemini_provider:
            return self._gemini_provider
        raise ValueError(f"Provider {p} 未初始化")

    def _resolve_generation_params(
        self,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        system_prompt: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Merge per-call generation overrides with the service defaults.

        ``temperature`` is checked against ``None`` (not truthiness) so an
        explicit ``0.0`` is honored. The other fields keep ``or`` semantics:
        empty/zero values fall back to the service defaults.
        """
        return {
            "model": model or self.default_model,
            "temperature": self.default_temperature if temperature is None else temperature,
            "max_tokens": max_tokens or self.default_max_tokens,
            "system_prompt": system_prompt or self.default_system_prompt,
        }

    @staticmethod
    def _usage_to_dict(usage: TokenUsage) -> Dict[str, Any]:
        """Render a ``TokenUsage`` as the ``usage`` dict returned to callers."""
        return {
            "prompt_tokens": usage.prompt_tokens,
            "completion_tokens": usage.completion_tokens,
            "total_tokens": usage.total_tokens,
        }

    def _build_call_metrics(
        self,
        *,
        request_mode: str,
        provider: Optional[str],
        model: Optional[str],
        prompt: str,
        auto_mcp: bool,
        tools_count: int,
        stream: bool,
    ) -> AICallMetrics:
        """Create the metrics record for one AI call, filling in service defaults."""
        return AICallMetrics(
            request_mode=request_mode,
            provider=normalize_provider(provider or self.api_provider) or "unknown",
            model=model or self.default_model,
            user_id=self.user_id,
            stream=stream,
            auto_mcp=auto_mcp,
            tools_count=tools_count,
            prompt_length=len(prompt or ""),
        )

    def _log_call_metrics(self, metrics: AICallMetrics, title: Optional[str] = None):
        """Log ``metrics`` at INFO on success, ERROR on failure."""
        log_title = title or ("AI调用完成" if metrics.success else "AI调用失败")
        message = metrics.to_log_message(log_title)
        if metrics.success:
            logger.info(message)
        else:
            logger.error(message)

    async def _prepare_mcp_tools(self, auto_mcp: bool = True, force_refresh: bool = False) -> Optional[List[Dict]]:
        """
        Load the MCP tools available to the current user.

        The result is cached on the instance. Note that an empty cached result
        is not short-circuited: the loader is queried again, and it keeps its own
        cache (``use_cache=True``).

        Args:
            auto_mcp: Whether the caller wants MCP tools loaded for this call.
            force_refresh: Bypass both the instance cache and the loader cache.

        Returns:
            ``None`` when no tools are available (not configured, disabled, or
            loading failed); otherwise the tool list in OpenAI format.
        """
        # Preconditions
        if not self._enable_mcp:
            logger.debug("🔧 MCP工具未启用 (_enable_mcp=False)")
            # Drop any cached tools so they cannot be used.
            self._cached_tools = None
            self._tools_loaded = False
            return None

        if not auto_mcp:
            logger.debug("🔧 auto_mcp=False,跳过MCP工具加载")
            # Drop any cached tools so they cannot be used.
            self._cached_tools = None
            self._tools_loaded = False
            return None

        if not self.user_id:
            logger.debug("🔧 MCP工具加载跳过: user_id未设置")
            return None

        if not self.db_session:
            logger.debug("🔧 MCP工具加载跳过: db_session未设置")
            return None

        # Serve from the cache (only reachable when MCP is enabled).
        if self._tools_loaded and not force_refresh:
            if self._cached_tools:
                logger.debug(f"🔧 使用缓存的MCP工具 ({len(self._cached_tools)}个)")
                return self._cached_tools

        try:
            from app.services.mcp_tools_loader import mcp_tools_loader

            self._cached_tools = await mcp_tools_loader.get_user_tools(
                user_id=self.user_id,
                db_session=self.db_session,
                use_cache=True,
                force_refresh=force_refresh,
            )
            self._tools_loaded = True

            if self._cached_tools:
                logger.info(f"🔧 已加载 {len(self._cached_tools)} 个MCP工具")
            else:
                logger.debug(f"📭 用户 {self.user_id} 没有可用的MCP工具")

            return self._cached_tools

        except Exception as e:
            # Best effort: a tool-loading failure must not break the AI call.
            logger.warning(f"⚠️ 加载MCP工具失败: {e}")
            self._tools_loaded = True
            self._cached_tools = None
            return None

    async def _handle_tool_calls(
        self,
        original_prompt: str,
        response: Dict[str, Any],
        max_rounds: int = 2,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Execute the tool calls requested by the model and re-query it with the results.

        Each round runs the pending tool calls through ``mcp_client``, appends the
        rendered results to the prompt and calls the provider again. Results from
        *all* completed rounds stay in the prompt, so later rounds still see what
        earlier rounds fetched. The final round sends no tools and
        ``tool_choice="none"`` to push the model to answer.

        Args:
            original_prompt: The caller's original prompt.
            response: Provider response containing ``tool_calls``.
            max_rounds: Maximum number of tool-call rounds.
            **kwargs: Options forwarded from ``generate_text``: ``provider``,
                ``model``, ``temperature``, ``max_tokens``, ``system_prompt``,
                ``tool_choice`` and ``tools`` (tools offered on follow-up rounds;
                defaults to the cached MCP tools when not given).

        Returns:
            ``response`` unchanged when it has no tool calls or no ``user_id`` is
            set. Otherwise a new dict with ``content``, ``finish_reason``,
            ``tool_calls_made``, ``tools_used``, ``mcp_enhanced``, aggregated
            ``usage`` and a private ``__tool_metrics`` entry.
        """
        from app.mcp import mcp_client

        tool_calls = response.get("tool_calls", [])
        if not tool_calls or not self.user_id:
            return response

        tool_metrics = ToolCallMetrics()
        tool_metrics.usage.add(TokenUsage.from_response(response))

        result: Dict[str, Any] = {
            "content": response.get("content", ""),
            "tool_calls_made": 0,
            "tools_used": [],
            "finish_reason": response.get("finish_reason", ""),
            "mcp_enhanced": True,
            "usage": response.get("usage"),
        }

        # Offer the same tools on follow-up rounds as on the first request;
        # previously caller-supplied tools were replaced by the MCP cache.
        followup_tools = kwargs["tools"] if "tools" in kwargs else self._cached_tools
        gen_params = self._resolve_generation_params(
            kwargs.get("model"),
            kwargs.get("temperature"),
            kwargs.get("max_tokens"),
            kwargs.get("system_prompt"),
        )
        tool_contexts: List[str] = []
        last_response = response

        for round_num in range(max_rounds):
            logger.info(f"🔧 工具调用 - 第{round_num+1}/{max_rounds}轮,{len(tool_calls)}个工具")
            tool_metrics.mcp_rounds += 1

            try:
                # Execute all pending tool calls in one batch.
                tool_results = await mcp_client.batch_call_tools(
                    user_id=self.user_id,
                    tool_calls=tool_calls,
                )

                # Record which tools were used.
                for tc in tool_calls:
                    name = tc["function"]["name"]
                    tool_metrics.add_tool_name(name)
                    if name not in result["tools_used"]:
                        result["tools_used"].append(name)
                result["tool_calls_made"] += len(tool_calls)
                tool_metrics.tool_calls_count += len(tool_calls)

                # Accumulate tool context across rounds instead of keeping only
                # the latest round, so earlier results are not lost.
                tool_contexts.append(mcp_client.build_tool_context(tool_results, format="markdown"))
                tool_context = "\n\n".join(tool_contexts)

                if round_num == max_rounds - 1:
                    # Last round: demand a final answer.
                    prompt = f"{original_prompt}\n\n{tool_context}\n\n⚠️ 重要:请基于以上工具查询结果,给出完整详细的最终答案。不要再调用工具。"
                    tool_choice = "none"
                else:
                    prompt = f"{original_prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,继续完成任务。"
                    tool_choice = kwargs.get("tool_choice", "auto")

                prov = self._get_provider(kwargs.get("provider"))
                next_response = await prov.generate(
                    prompt=prompt,
                    tools=None if tool_choice == "none" else followup_tools,
                    tool_choice=tool_choice,
                    **gen_params,
                )
                tool_metrics.usage.add(TokenUsage.from_response(next_response))
                last_response = next_response

                tool_calls = next_response.get("tool_calls", [])
                if not tool_calls:
                    # No further tool calls: this is the final answer.
                    result["content"] = next_response.get("content", "")
                    result["finish_reason"] = next_response.get("finish_reason", "stop")
                    break

            except Exception as e:
                logger.error(f"❌ 工具调用失败: {e}")
                tool_metrics.tool_error_count += 1
                result["content"] = response.get("content", "")
                result["finish_reason"] = "tool_error"
                break
        else:
            # Rounds exhausted while the model still asked for tools (e.g. the
            # provider ignored tool_choice="none"): surface the latest reply
            # rather than the stale first response.
            result["content"] = last_response.get("content", "")
            result["finish_reason"] = last_response.get("finish_reason", result["finish_reason"])

        # Always report usage aggregated over every model call in this exchange.
        result["usage"] = self._usage_to_dict(tool_metrics.usage)
        result["__tool_metrics"] = tool_metrics
        return result

    async def generate_text(
        self,
        prompt: str,
        provider: Optional[str] = None,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        system_prompt: Optional[str] = None,
        tools: Optional[List[Dict]] = None,
        tool_choice: Optional[str] = None,
        auto_mcp: bool = True,
        handle_tool_calls: bool = True,
        mcp_max_rounds: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Generate text, with automatic MCP tool support.

        Args:
            prompt: User prompt.
            provider: AI provider (defaults to the service's provider).
            model: Model name.
            temperature: Sampling temperature; ``0.0`` is honored.
            max_tokens: Maximum tokens.
            system_prompt: System prompt.
            tools: Explicit tool list (takes precedence over auto-loading).
            tool_choice: Tool selection strategy.
            auto_mcp: Whether to load MCP tools automatically (default True).
            handle_tool_calls: Whether to execute tool calls automatically (default True).
            mcp_max_rounds: Maximum tool-call rounds; ``None`` uses
                ``app_settings.mcp_max_rounds``.

        Returns:
            The provider response dict, or the tool-call result dict from
            ``_handle_tool_calls`` when tool calls were handled.
        """
        # Fall back to the globally configured number of MCP rounds.
        if mcp_max_rounds is None:
            mcp_max_rounds = app_settings.mcp_max_rounds

        # Load MCP tools automatically unless tools were given explicitly.
        if auto_mcp and tools is None:
            tools = await self._prepare_mcp_tools(auto_mcp=auto_mcp)

        metrics = self._build_call_metrics(
            request_mode="文本",
            provider=provider,
            model=model,
            prompt=prompt,
            auto_mcp=auto_mcp,
            tools_count=len(tools) if tools else 0,
            stream=False,
        )

        try:
            prov = self._get_provider(provider)
            response = await prov.generate(
                prompt=prompt,
                tools=tools,
                tool_choice=tool_choice,
                **self._resolve_generation_params(model, temperature, max_tokens, system_prompt),
            )
            usage = TokenUsage.from_response(response)

            # Execute tool calls and re-query the model.
            if handle_tool_calls and response.get("tool_calls"):
                response = await self._handle_tool_calls(
                    original_prompt=prompt,
                    response=response,
                    provider=provider,
                    model=model,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    system_prompt=system_prompt,
                    tool_choice=tool_choice,
                    tools=tools,
                    max_rounds=mcp_max_rounds,
                )
                usage = TokenUsage.from_response(response)
                tool_metrics = response.get("__tool_metrics")
                if tool_metrics:
                    metrics.merge_tool_metrics(tool_metrics)

            metrics.finish(
                success=True,
                response_length=len(response.get("content", "") or ""),
                finish_reason=response.get("finish_reason"),
                usage=usage,
            )
            self._log_call_metrics(metrics)
            return response
        except Exception as e:
            metrics.finish(success=False, error=e)
            self._log_call_metrics(metrics)
            raise

    async def generate_text_stream(
        self,
        prompt: str,
        provider: Optional[str] = None,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        system_prompt: Optional[str] = None,
        tool_choice: Optional[str] = None,
        auto_mcp: bool = True,
        mcp_max_rounds: Optional[int] = None,
    ) -> AsyncGenerator[str, None]:
        """
        Stream generated text, with automatic MCP tool support.

        Tool calls are handled inside the provider's ``generate_stream``; dict
        chunks from the provider carry ``usage``/``finish_reason`` metadata and
        are consumed here rather than yielded.

        Args:
            prompt: User prompt.
            provider: AI provider.
            model: Model name.
            temperature: Sampling temperature; ``0.0`` is honored.
            max_tokens: Maximum tokens.
            system_prompt: System prompt.
            tool_choice: Tool selection strategy ("auto"/"none"/"required").
            auto_mcp: Whether to load MCP tools automatically.
            mcp_max_rounds: Accepted for interface compatibility; this method
                does not forward it to the provider.

        Yields:
            Non-empty text chunks.
        """
        logger.debug(f"🔧 generate_text_stream: auto_mcp={auto_mcp}, tool_choice={tool_choice}")

        tools_to_use = None

        # Load MCP tools
        if auto_mcp:
            tools_to_use = await self._prepare_mcp_tools(auto_mcp=auto_mcp)
            if tools_to_use:
                logger.info(f"🔧 已获取 {len(tools_to_use)} 个MCP工具")

        metrics = self._build_call_metrics(
            request_mode="流式文本",
            provider=provider,
            model=model,
            prompt=prompt,
            auto_mcp=auto_mcp,
            tools_count=len(tools_to_use) if tools_to_use else 0,
            stream=True,
        )
        # Running length of the yielded text (avoids buffering every chunk).
        response_length = 0
        latest_usage = TokenUsage()
        finish_reason = "stop"

        try:
            prov = self._get_provider(provider)
            logger.debug(f"🔧 开始流式生成,provider={provider or self.api_provider}, tools_count={len(tools_to_use) if tools_to_use else 0}")
            async for chunk in prov.generate_stream(
                prompt=prompt,
                tools=tools_to_use,
                tool_choice=tool_choice,
                user_id=self.user_id,
                **self._resolve_generation_params(model, temperature, max_tokens, system_prompt),
            ):
                if isinstance(chunk, dict):
                    # Metadata chunk: record usage / finish reason, do not yield.
                    if chunk.get("usage"):
                        latest_usage = TokenUsage.from_response({"usage": chunk.get("usage")})
                    if chunk.get("finish_reason"):
                        finish_reason = chunk.get("finish_reason") or finish_reason
                    continue

                if chunk:
                    metrics.mark_first_chunk()
                    metrics.chunk_count += 1
                    response_length += len(chunk)
                    yield chunk

            metrics.finish(
                success=True,
                response_length=response_length,
                finish_reason=finish_reason,
                usage=latest_usage,
            )
            self._log_call_metrics(metrics)
        except Exception as e:
            metrics.finish(
                success=False,
                response_length=response_length,
                finish_reason=finish_reason,
                usage=latest_usage,
                error=e,
            )
            self._log_call_metrics(metrics)
            raise

    async def call_with_json_retry(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_retries: int = 3,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        provider: Optional[str] = None,
        model: Optional[str] = None,
        expected_type: Optional[str] = None,
        auto_mcp: bool = True,
    ) -> Union[Dict, List]:
        """
        Call the model and parse its reply as JSON, retrying on parse failures.

        Only JSON parsing / type validation failures are retried. Errors raised
        by ``generate_text`` itself propagate immediately.

        Args:
            prompt: User prompt.
            system_prompt: System prompt.
            max_retries: Maximum number of attempts.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens.
            provider: AI provider.
            model: Model name.
            expected_type: Required top-level type, "object" or "array".
            auto_mcp: Whether to load MCP tools automatically.

        Returns:
            The parsed JSON data.

        Raises:
            ValueError: if no attempt produced valid JSON of the expected type.
        """
        last_response = ""
        aggregate_usage = TokenUsage()
        metrics = self._build_call_metrics(
            request_mode="JSON重试",
            provider=provider,
            model=model,
            prompt=prompt,
            auto_mcp=auto_mcp,
            tools_count=0,
            stream=False,
        )

        try:
            for attempt in range(1, max_retries + 1):
                current_prompt = prompt if attempt == 1 else self._add_json_hint(prompt, last_response, attempt)

                result = await self.generate_text(
                    prompt=current_prompt,
                    provider=provider,
                    model=model,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    system_prompt=system_prompt,
                    auto_mcp=auto_mcp,
                    handle_tool_calls=True,
                )
                aggregate_usage.add(TokenUsage.from_response(result))
                metrics.retry_count = attempt
                metrics.tools_count = max(metrics.tools_count, len(self._cached_tools) if self._cached_tools else 0)

                last_response = result.get("content", "")

                try:
                    data = parse_json(last_response)
                    if expected_type == "object" and not isinstance(data, dict):
                        raise ValueError("期望对象")
                    if expected_type == "array" and not isinstance(data, list):
                        raise ValueError("期望数组")
                    metrics.json_parse_success = True
                    metrics.finish(
                        success=True,
                        response_length=len(last_response),
                        finish_reason=result.get("finish_reason"),
                        usage=aggregate_usage,
                    )
                    self._log_call_metrics(metrics, title="AI调用汇总")
                    return data
                except Exception as e:
                    metrics.json_parse_success = False
                    if attempt == max_retries:
                        # Chain the parse error so the root cause stays visible.
                        raise ValueError(f"JSON 解析失败: {e}") from e

            # Only reached when max_retries < 1.
            raise ValueError("JSON 调用失败")
        except Exception as e:
            metrics.finish(
                success=False,
                response_length=len(last_response),
                usage=aggregate_usage,
                error=e,
            )
            self._log_call_metrics(metrics, title="AI调用汇总")
            raise

    @staticmethod
    def _add_json_hint(prompt: str, failed: str, attempt: int) -> str:
        """Append a retry hint quoting (up to 200 chars of) the previous reply."""
        return f"{prompt}\n\n⚠️ 第{attempt}次重试,请返回纯JSON,不要markdown包裹。上次错误: {failed[:200]}..."

    @staticmethod
    def _clean_json_response(text: str) -> str:
        """Clean a JSON response (delegates to ``clean_json_response``)."""
        return clean_json_response(text)
|
||
|
||
|
||
def create_user_ai_service(
    api_provider: str,
    api_key: str,
    api_base_url: str,
    model_name: str,
    temperature: float,
    max_tokens: int,
    system_prompt: Optional[str] = None,
) -> AIService:
    """Build an ``AIService`` from a user's AI settings, without MCP context.

    No ``user_id`` or ``db_session`` is supplied, so the resulting service
    never loads MCP tools.
    """
    service = AIService(
        default_system_prompt=system_prompt,
        default_max_tokens=max_tokens,
        default_temperature=temperature,
        default_model=model_name,
        api_base_url=api_base_url,
        api_key=api_key,
        api_provider=api_provider,
    )
    return service
|
||
|
||
|
||
def create_user_ai_service_with_mcp(
    api_provider: str,
    api_key: str,
    api_base_url: str,
    model_name: str,
    temperature: float,
    max_tokens: int,
    user_id: str,
    db_session,
    system_prompt: Optional[str] = None,
    enable_mcp: bool = True,
) -> AIService:
    """
    Build an MCP-capable ``AIService`` for a user.

    Args:
        api_provider: AI provider name.
        api_key: API key for that provider.
        api_base_url: API base URL.
        model_name: Default model name.
        temperature: Default sampling temperature.
        max_tokens: Default maximum tokens.
        user_id: User id, used to load that user's MCP tools.
        db_session: Database session passed to the MCP tool loader.
        system_prompt: Default system prompt.
        enable_mcp: Whether MCP tools may be used.

    Returns:
        The configured ``AIService`` instance.
    """
    mcp_context = {
        "user_id": user_id,
        "db_session": db_session,
        "enable_mcp": enable_mcp,
    }
    return AIService(
        api_provider=api_provider,
        api_key=api_key,
        api_base_url=api_base_url,
        default_model=model_name,
        default_temperature=temperature,
        default_max_tokens=max_tokens,
        default_system_prompt=system_prompt,
        **mcp_context,
    )