update:1.优化 AI 流式生成和进度显示系统 2.新增写作风格系统提示词支持 3.灵感模式功能增强,支持灵感重写 4.设置页面功能扩展,新增Gemini适配器 5.提示词模板系统优化,调整灵感模式提示词
This commit is contained in:
@@ -0,0 +1,86 @@
|
||||
"""Anthropic 客户端"""
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
|
||||
from anthropic import AsyncAnthropic
|
||||
|
||||
from app.logger import get_logger
|
||||
from app.services.ai_config import AIClientConfig, default_config
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AnthropicClient:
|
||||
"""Anthropic API 客户端"""
|
||||
|
||||
def __init__(self, api_key: str, base_url: Optional[str] = None, config: Optional[AIClientConfig] = None):
|
||||
self.config = config or default_config
|
||||
kwargs = {"api_key": api_key}
|
||||
if base_url:
|
||||
kwargs["base_url"] = base_url
|
||||
self.client = AsyncAnthropic(**kwargs)
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
kwargs = {
|
||||
"model": model,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"messages": messages,
|
||||
}
|
||||
if system_prompt:
|
||||
kwargs["system"] = system_prompt
|
||||
if tools:
|
||||
kwargs["tools"] = tools
|
||||
if tool_choice == "required":
|
||||
kwargs["tool_choice"] = {"type": "any"}
|
||||
elif tool_choice == "auto":
|
||||
kwargs["tool_choice"] = {"type": "auto"}
|
||||
|
||||
response = await self.client.messages.create(**kwargs)
|
||||
|
||||
tool_calls = []
|
||||
content = ""
|
||||
for block in response.content:
|
||||
if block.type == "tool_use":
|
||||
tool_calls.append({
|
||||
"id": block.id,
|
||||
"type": "function",
|
||||
"function": {"name": block.name, "arguments": block.input},
|
||||
})
|
||||
elif block.type == "text":
|
||||
content += block.text
|
||||
|
||||
return {
|
||||
"content": content,
|
||||
"tool_calls": tool_calls if tool_calls else None,
|
||||
"finish_reason": response.stop_reason,
|
||||
}
|
||||
|
||||
async def chat_completion_stream(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
kwargs = {
|
||||
"model": model,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"messages": messages,
|
||||
}
|
||||
if system_prompt:
|
||||
kwargs["system"] = system_prompt
|
||||
|
||||
async with self.client.messages.stream(**kwargs) as stream:
|
||||
async for text in stream.text_stream:
|
||||
yield text
|
||||
Reference in New Issue
Block a user