fba6922a5c
- JSON解析器字符串状态追踪修复 - AI客户端流式响应异常处理 - 写作风格MultipleResultsFound错误 - 职业stages字段类型处理 - 章节分析任务状态同步 - 后台任务返回值修复
101 lines
3.3 KiB
Python
101 lines
3.3 KiB
Python
"""Anthropic 客户端"""
|
|
from typing import Any, AsyncGenerator, Dict, Optional
|
|
|
|
from anthropic import AsyncAnthropic
|
|
|
|
from app.logger import get_logger
|
|
from app.services.ai_config import AIClientConfig, default_config
|
|
|
|
# Logger for this module, obtained from the project's get_logger helper.
logger = get_logger(__name__)
|
|
|
|
|
|
class AnthropicClient:
    """Async client for the Anthropic Messages API.

    Wraps an ``AsyncAnthropic`` instance and exposes a one-shot completion
    call and a streaming completion call.
    """

    def __init__(self, api_key: str, base_url: Optional[str] = None, config: Optional[AIClientConfig] = None):
        """Create the underlying ``AsyncAnthropic`` client.

        Args:
            api_key: API key passed to ``AsyncAnthropic``.
            base_url: Optional endpoint override; only forwarded when truthy.
            config: Client configuration; falls back to ``default_config``
                when not provided.
        """
        self.config = config or default_config
        client_kwargs: Dict[str, Any] = {"api_key": api_key}
        if base_url:
            client_kwargs["base_url"] = base_url
        self.client = AsyncAnthropic(**client_kwargs)

    @staticmethod
    def _build_request_kwargs(
        messages: list,
        model: str,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str],
    ) -> Dict[str, Any]:
        """Build the request arguments shared by both completion methods."""
        request: Dict[str, Any] = {
            "model": model,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "messages": messages,
        }
        # The system prompt is only sent when one was actually supplied.
        if system_prompt:
            request["system"] = system_prompt
        return request

    async def chat_completion(
        self,
        messages: list,
        model: str,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str] = None,
        tools: Optional[list] = None,
        tool_choice: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Send a non-streaming request and normalise the response.

        Args:
            messages: Conversation messages forwarded as ``messages``.
            model: Model identifier.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.
            system_prompt: Optional system prompt, sent as ``system``.
            tools: Optional tool definitions; forwarded only when truthy.
            tool_choice: ``"required"`` is sent as ``{"type": "any"}`` and
                ``"auto"`` as ``{"type": "auto"}``. Any other value is
                ignored, and the whole option is ignored without ``tools``.

        Returns:
            A dict with three keys:
            ``content`` - concatenated text of all ``text`` blocks;
            ``tool_calls`` - list of ``{"id", "type": "function",
            "function": {"name", "arguments"}}`` dicts built from
            ``tool_use`` blocks (``arguments`` is the block's ``input``
            passed through unchanged), or ``None`` when there are none;
            ``finish_reason`` - the response's ``stop_reason``.
        """
        kwargs = self._build_request_kwargs(
            messages, model, temperature, max_tokens, system_prompt
        )
        if tools:
            kwargs["tools"] = tools
            if tool_choice == "required":
                kwargs["tool_choice"] = {"type": "any"}
            elif tool_choice == "auto":
                kwargs["tool_choice"] = {"type": "auto"}

        response = await self.client.messages.create(**kwargs)

        tool_calls = []
        text_parts = []
        for block in response.content:
            if block.type == "tool_use":
                tool_calls.append({
                    "id": block.id,
                    "type": "function",
                    "function": {"name": block.name, "arguments": block.input},
                })
            elif block.type == "text":
                text_parts.append(block.text)

        return {
            "content": "".join(text_parts),
            "tool_calls": tool_calls or None,
            "finish_reason": response.stop_reason,
        }

    async def chat_completion_stream(
        self,
        messages: list,
        model: str,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str] = None,
    ) -> AsyncGenerator[str, None]:
        """Stream a completion, yielding text chunks as they arrive.

        Args:
            messages: Conversation messages forwarded as ``messages``.
            model: Model identifier.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.
            system_prompt: Optional system prompt, sent as ``system``.

        Yields:
            Each text chunk produced by the stream's ``text_stream``.

        Raises:
            Exception: Any error raised while opening, iterating or closing
                the stream is logged once and re-raised unchanged.
        """
        kwargs = self._build_request_kwargs(
            messages, model, temperature, max_tokens, system_prompt
        )

        # Set once the iteration handler has logged a failure, so the outer
        # handler does not log the same exception a second time.
        iteration_error_logged = False
        try:
            async with self.client.messages.stream(**kwargs) as stream:
                try:
                    async for text in stream.text_stream:
                        yield text
                except GeneratorExit:
                    # The generator is being closed; this is normal cleanup.
                    logger.debug("Anthropic 流式响应生成器被关闭(GeneratorExit)")
                    raise
                except Exception as iter_error:
                    iteration_error_logged = True
                    logger.error(f"Anthropic 流式响应迭代出错: {str(iter_error)}")
                    raise
        # GeneratorExit derives from BaseException, so it passes through this
        # handler untouched and reaches the caller without an explicit re-raise.
        except Exception as e:
            if not iteration_error_logged:
                logger.error(f"Anthropic 流式请求出错: {str(e)}")
            raise