296 lines
13 KiB
Python
296 lines
13 KiB
Python
"""AI服务封装 - 统一的AI接口"""
|
|
from typing import Optional, AsyncGenerator, List, Dict, Any, Union
|
|
|
|
from app.config import settings as app_settings
|
|
from app.logger import get_logger
|
|
from app.services.ai_config import AIClientConfig, default_config
|
|
from app.services.ai_clients.openai_client import OpenAIClient
|
|
from app.services.ai_clients.anthropic_client import AnthropicClient
|
|
from app.services.ai_clients.gemini_client import GeminiClient
|
|
from app.services.ai_clients.base_client import cleanup_all_clients
|
|
from app.services.ai_providers.openai_provider import OpenAIProvider
|
|
from app.services.ai_providers.anthropic_provider import AnthropicProvider
|
|
from app.services.ai_providers.gemini_provider import GeminiProvider
|
|
from app.services.ai_providers.base_provider import BaseAIProvider
|
|
from app.services.json_helper import clean_json_response, parse_json
|
|
from app.mcp.adapters.universal import universal_mcp_adapter
|
|
|
|
# Module-level logger for this service module.
logger = get_logger(__name__)

# Public alias: the shared HTTP-client cleanup hook is exposed from this module
# under the name `cleanup_http_clients`.
cleanup_http_clients = cleanup_all_clients
|
|
|
|
|
|
class AIService:
    """Unified AI service interface.

    Holds one provider object per backend (OpenAI, Anthropic, Gemini) and
    routes generation requests to the selected one. Per-call arguments that
    are left unset fall back to the per-instance defaults.
    """

    def __init__(
        self,
        api_provider: Optional[str] = None,
        api_key: Optional[str] = None,
        api_base_url: Optional[str] = None,
        default_model: Optional[str] = None,
        default_temperature: Optional[float] = None,
        default_max_tokens: Optional[int] = None,
        default_system_prompt: Optional[str] = None,
        enable_mcp_adapter: bool = True,
        config: Optional[AIClientConfig] = None,
    ):
        """Build the service and initialise every backend that has a key.

        Args:
            api_provider: Default backend name ("openai", "anthropic" or
                "gemini"). Falls back to ``app_settings.default_ai_provider``.
            api_key: Key used only for the backend named by ``api_provider``.
                The OpenAI/Anthropic backends otherwise read their key from
                application settings.
            api_base_url: Base URL used only for the backend named by
                ``api_provider``.
            default_model: Model used when a call does not specify one.
            default_temperature: Temperature used when a call does not specify
                one. An explicit ``0.0`` is kept; only ``None`` falls back to
                ``app_settings.default_temperature``.
            default_max_tokens: Token limit used when a call does not specify
                one. Any falsy value falls back to settings.
            default_system_prompt: System prompt used when a call passes none.
            enable_mcp_adapter: Whether to attach ``universal_mcp_adapter``.
            config: Client configuration. Defaults to ``default_config``.
        """
        self.api_provider = api_provider or app_settings.default_ai_provider
        self.default_model = default_model or app_settings.default_model
        # Compare against None: 0.0 is a legitimate (deterministic) temperature
        # and must not be replaced by the global default.
        self.default_temperature = (
            default_temperature
            if default_temperature is not None
            else app_settings.default_temperature
        )
        # 0 is not a usable token limit, so a falsy value still means "use default".
        self.default_max_tokens = default_max_tokens or app_settings.default_max_tokens
        self.default_system_prompt = default_system_prompt
        self.config = config or default_config

        self.mcp_adapter = universal_mcp_adapter if enable_mcp_adapter else None

        # One provider slot per backend; None means "not initialised".
        self._openai_provider: Optional[OpenAIProvider] = None
        self._anthropic_provider: Optional[AnthropicProvider] = None
        self._gemini_provider: Optional[GeminiProvider] = None

        # OpenAI: explicit key/base URL when selected, otherwise app settings.
        openai_key = api_key if api_provider == "openai" else app_settings.openai_api_key
        if openai_key:
            base_url = api_base_url if api_provider == "openai" else app_settings.openai_base_url
            client = OpenAIClient(openai_key, base_url or "https://api.openai.com/v1", self.config)
            self._openai_provider = OpenAIProvider(client)

        # Anthropic: same selection rule. The base URL is passed through even if None.
        anthropic_key = api_key if api_provider == "anthropic" else app_settings.anthropic_api_key
        if anthropic_key:
            base_url = api_base_url if api_provider == "anthropic" else app_settings.anthropic_base_url
            client = AnthropicClient(anthropic_key, base_url, self.config)
            self._anthropic_provider = AnthropicProvider(client)

        # Gemini: only initialised when explicitly selected with a key.
        # No Gemini credentials are read from app settings here.
        if api_provider == "gemini" and api_key:
            client = GeminiClient(api_key, api_base_url, self.config)
            self._gemini_provider = GeminiProvider(client)

    def _get_provider(self, provider: Optional[str] = None) -> BaseAIProvider:
        """Return the initialised provider for ``provider`` (or the default).

        Raises:
            ValueError: If that backend is unknown or was not initialised.
        """
        p = provider or self.api_provider
        if p == "openai" and self._openai_provider:
            return self._openai_provider
        if p == "anthropic" and self._anthropic_provider:
            return self._anthropic_provider
        if p == "gemini" and self._gemini_provider:
            return self._gemini_provider
        raise ValueError(f"Provider {p} 未初始化")

    def _resolve_temperature(self, temperature: Optional[float]) -> Optional[float]:
        """Return ``temperature``, or the instance default when it is None.

        Unlike ``temperature or default``, this keeps an explicit 0.0.
        """
        return self.default_temperature if temperature is None else temperature

    async def generate_text(
        self,
        prompt: str,
        provider: Optional[str] = None,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        system_prompt: Optional[str] = None,
        tools: Optional[List[Dict]] = None,
        tool_choice: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Generate a completion with the selected provider.

        Unset arguments fall back to the instance defaults. An explicit
        ``temperature=0.0`` is passed through unchanged.

        Returns:
            The provider's response dict. Callers in this class read the keys
            ``"content"``, ``"tool_calls"`` and ``"finish_reason"``.

        Raises:
            ValueError: If the requested provider is not initialised.
        """
        prov = self._get_provider(provider)
        return await prov.generate(
            prompt=prompt,
            model=model or self.default_model,
            temperature=self._resolve_temperature(temperature),
            max_tokens=max_tokens or self.default_max_tokens,
            system_prompt=system_prompt or self.default_system_prompt,
            tools=tools,
            tool_choice=tool_choice,
        )

    async def generate_text_stream(
        self,
        prompt: str,
        provider: Optional[str] = None,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        system_prompt: Optional[str] = None,
    ) -> AsyncGenerator[str, None]:
        """Stream a completion chunk by chunk from the selected provider.

        Defaults are resolved exactly as in :meth:`generate_text`.
        """
        prov = self._get_provider(provider)
        async for chunk in prov.generate_stream(
            prompt=prompt,
            model=model or self.default_model,
            temperature=self._resolve_temperature(temperature),
            max_tokens=max_tokens or self.default_max_tokens,
            system_prompt=system_prompt or self.default_system_prompt,
        ):
            yield chunk

    async def call_with_json_retry(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_retries: int = 3,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        provider: Optional[str] = None,
        model: Optional[str] = None,
        expected_type: Optional[str] = None,
    ) -> Union[Dict, List]:
        """Call the model and parse its reply as JSON, retrying on failure.

        From the second attempt on, the prompt is extended with a hint that
        quotes the start of the previous, unparseable reply.

        Args:
            expected_type: ``"object"`` to require a dict, ``"array"`` to
                require a list, or None to accept either.

        Raises:
            ValueError: If the last attempt still fails to parse or has the
                wrong type (chained to the underlying error), or if
                ``max_retries`` is less than 1.
        """
        last_response = ""

        for attempt in range(1, max_retries + 1):
            current_prompt = (
                prompt if attempt == 1 else self._add_json_hint(prompt, last_response, attempt)
            )

            result = await self.generate_text(
                prompt=current_prompt,
                provider=provider,
                model=model,
                temperature=temperature,
                max_tokens=max_tokens,
                system_prompt=system_prompt,
            )

            # `or ""` guards against content=None. Without it, the next
            # attempt's hint would slice None and crash instead of retrying.
            last_response = result.get("content") or ""

            try:
                data = parse_json(last_response)
                if expected_type == "object" and not isinstance(data, dict):
                    raise ValueError("期望对象")
                if expected_type == "array" and not isinstance(data, list):
                    raise ValueError("期望数组")
            except Exception as e:
                if attempt == max_retries:
                    raise ValueError(f"JSON 解析失败: {e}") from e
                continue
            return data

        # Only reachable when max_retries < 1.
        raise ValueError("JSON 调用失败")

    @staticmethod
    def _add_json_hint(prompt: str, failed: str, attempt: int) -> str:
        """Append a retry hint quoting the first 200 chars of the failed reply."""
        return f"{prompt}\n\n⚠️ 第{attempt}次重试,请返回纯JSON,不要markdown包裹。上次错误: {failed[:200]}..."

    @staticmethod
    def _clean_json_response(text: str) -> str:
        """Clean a JSON response by delegating to ``clean_json_response``."""
        return clean_json_response(text)

    async def generate_text_with_mcp(
        self,
        prompt: str,
        user_id: str,
        db_session,
        enable_mcp: bool = True,
        max_tool_rounds: int = 3,
        tool_choice: str = "auto",
        **kwargs
    ) -> Dict[str, Any]:
        """Generate text while letting the model call the user's MCP tools.

        Each round asks the model for an answer. If it requests tool calls,
        they are executed and their results are rendered as markdown context.
        That context is appended to the original prompt for the next round,
        and tools are disabled to avoid repeated calls. If the round budget
        runs out right after executing tools, one final tool-free call is made
        so that the tool results are actually turned into an answer.

        Args:
            prompt: The user prompt.
            user_id: Owner of the MCP tool configuration.
            db_session: Database session forwarded to ``mcp_tool_service``.
            enable_mcp: When False, no tools are offered to the model.
            max_tool_rounds: Number of model calls allowed in the tool loop.
            tool_choice: Initial tool-choice mode passed to the provider.
            **kwargs: Forwarded to :meth:`generate_text` (model, temperature, ...).

        Returns:
            A dict with the keys ``content``, ``tool_calls_made``,
            ``tools_used``, ``finish_reason`` and ``mcp_enhanced``.
        """
        # Local import. This is presumably done to avoid an import cycle; confirm.
        from app.services.mcp_tool_service import mcp_tool_service, MCPToolServiceError

        result: Dict[str, Any] = {
            "content": "",
            "tool_calls_made": 0,
            "tools_used": [],
            "finish_reason": "",
            "mcp_enhanced": False,
        }
        tools = None

        if enable_mcp:
            try:
                tools = await mcp_tool_service.get_user_enabled_tools(user_id=user_id, db_session=db_session)
                if tools:
                    result["mcp_enhanced"] = True
            except MCPToolServiceError:
                # Tool discovery is best-effort: fall back to plain generation.
                tools = None

        original_prompt = prompt  # keep the untouched prompt for rebuilding context
        # True when tool results have been folded into `prompt` but the model
        # has not yet been asked to answer with them.
        awaiting_answer = False

        for round_num in range(max_tool_rounds):
            logger.debug(f"🔄 MCP工具调用 - 第{round_num+1}/{max_tool_rounds}轮")
            logger.debug(f" prompt长度: {len(prompt)}, tools数量: {len(tools) if tools else 0}, tool_choice: {tool_choice}")

            ai_response = await self.generate_text(prompt=prompt, tools=tools, tool_choice=tool_choice, **kwargs)
            awaiting_answer = False
            # Normalise a None content so len()/slicing below cannot crash.
            content = ai_response.get("content") or ""
            logger.debug(f" AI响应: finish_reason={ai_response.get('finish_reason')}, content长度={len(content)}")

            tool_calls = ai_response.get("tool_calls") or []

            if not tool_calls:
                result["content"] = content
                result["finish_reason"] = ai_response.get("finish_reason", "stop")
                logger.debug(f" ✅ 无工具调用,返回内容长度: {len(content)}")

                # Empty answer after tool use: ask again explicitly, with tools off.
                if not content.strip() and result["tool_calls_made"] > 0:
                    logger.warning(f"⚠️ AI在工具调用后返回空内容,尝试强制要求回答(第{round_num+1}轮)")
                    prompt = f"{prompt}\n\n⚠️ 请注意:你必须基于以上工具查询结果,给出完整的回答。不要返回空内容。"
                    tools = None
                    tool_choice = "none"  # forbid further tool use
                    continue

                break

            logger.info(f"🔧 检测到 {len(tool_calls)} 个工具调用")
            for idx, tc in enumerate(tool_calls):
                fn = tc.get("function", {})
                logger.debug(f" 工具{idx+1}: {fn.get('name')} - 参数: {fn.get('arguments')}")

            try:
                logger.debug(" 开始执行工具调用...")
                tool_results = await mcp_tool_service.execute_tool_calls(user_id=user_id, tool_calls=tool_calls, db_session=db_session)
                logger.debug(f" 工具执行完成,结果数量: {len(tool_results)}")

                for idx, tr in enumerate(tool_results):
                    success = tr.get("success", False)
                    content_preview = tr.get("content", "")[:200] if tr.get("content") else "None"
                    logger.debug(f" 工具结果[{idx}]: success={success}, content预览={content_preview}")

                for tc in tool_calls:
                    name = tc["function"]["name"]
                    if name not in result["tools_used"]:
                        result["tools_used"].append(name)
                result["tool_calls_made"] += len(tool_calls)

                tool_context = await mcp_tool_service.build_tool_context(tool_results, format="markdown")
                logger.debug(f" 工具上下文长度: {len(tool_context)}")
                logger.debug(f" 工具上下文预览: {tool_context[:300]}")

                if round_num == max_tool_rounds - 1:
                    # Last round: the follow-up call after the loop must produce the answer.
                    logger.info("⚠️ 最后一轮,强制要求AI给出最终答案")
                    prompt = f"{original_prompt}\n\n{tool_context}\n\n⚠️ 重要:这是最后一轮,请基于以上工具查询的参考资料,给出完整详细的最终答案。不要再调用工具。"
                    tool_choice = "none"
                else:
                    prompt = f"{original_prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,继续完成任务。"
                logger.debug(f" 新prompt长度: {len(prompt)}")

                tools = None  # disable tools after a tool round to avoid repeated calls
                awaiting_answer = True
                logger.debug(" ✅ 工具调用成功,准备下一轮")

            except Exception as tool_error:
                logger.error(f"❌ 工具调用执行失败: {tool_error}", exc_info=True)
                logger.error(f" 错误类型: {type(tool_error).__name__}")
                logger.error(f" AI响应内容: {content[:200]}")
                result["content"] = content
                result["finish_reason"] = "tool_error"
                break

        if awaiting_answer:
            # The loop ended right after executing tools, so the model never saw
            # their results. Ask once more, without tools, for the final answer.
            logger.info("⚠️ 工具轮次已用尽,基于工具结果请求最终答案")
            final_response = await self.generate_text(prompt=prompt, tools=None, tool_choice="none", **kwargs)
            result["content"] = final_response.get("content") or ""
            result["finish_reason"] = final_response.get("finish_reason", "stop")

        return result
|
|
|
|
|
|
# Shared module-level instance built with all-default arguments: provider,
# model and OpenAI/Anthropic keys all come from application settings.
ai_service: AIService = AIService()
|
|
|
|
|
|
def create_user_ai_service(
    api_provider: str,
    api_key: str,
    api_base_url: str,
    model_name: str,
    temperature: float,
    max_tokens: int,
    system_prompt: Optional[str] = None,
) -> AIService:
    """Build an AIService bound to a single user's provider credentials.

    The user's model, temperature, token limit and system prompt become the
    service's per-call defaults. The MCP adapter flag and client config keep
    their AIService defaults.
    """
    service_kwargs = dict(
        api_provider=api_provider,
        api_key=api_key,
        api_base_url=api_base_url,
        default_model=model_name,
        default_temperature=temperature,
        default_max_tokens=max_tokens,
        default_system_prompt=system_prompt,
    )
    return AIService(**service_kwargs)