支持自定义API接口
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
from typing import Optional, AsyncGenerator, List, Dict, Any
|
||||
from openai import AsyncOpenAI
|
||||
from anthropic import AsyncAnthropic
|
||||
from app.config import settings
|
||||
from app.config import settings as app_settings
|
||||
from app.logger import get_logger
|
||||
import httpx
|
||||
|
||||
@@ -10,12 +10,37 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AIService:
|
||||
"""AI服务统一接口"""
|
||||
"""AI服务统一接口 - 支持从用户设置或全局配置初始化"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化AI客户端(优化并发性能)"""
|
||||
def __init__(
|
||||
self,
|
||||
api_provider: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
default_model: Optional[str] = None,
|
||||
default_temperature: Optional[float] = None,
|
||||
default_max_tokens: Optional[int] = None
|
||||
):
|
||||
"""
|
||||
初始化AI客户端(优化并发性能)
|
||||
|
||||
Args:
|
||||
api_provider: API提供商 (openai/anthropic),为None时使用全局配置
|
||||
api_key: API密钥,为None时使用全局配置
|
||||
api_base_url: API基础URL,为None时使用全局配置
|
||||
default_model: 默认模型,为None时使用全局配置
|
||||
default_temperature: 默认温度,为None时使用全局配置
|
||||
default_max_tokens: 默认最大tokens,为None时使用全局配置
|
||||
"""
|
||||
# 保存用户设置或使用全局配置
|
||||
self.api_provider = api_provider or app_settings.default_ai_provider
|
||||
self.default_model = default_model or app_settings.default_model
|
||||
self.default_temperature = default_temperature or app_settings.default_temperature
|
||||
self.default_max_tokens = default_max_tokens or app_settings.default_max_tokens
|
||||
|
||||
# 初始化OpenAI客户端
|
||||
if settings.openai_api_key:
|
||||
openai_key = api_key if api_provider == "openai" else app_settings.openai_api_key
|
||||
if openai_key:
|
||||
# 创建自定义的httpx客户端来避免proxies参数问题
|
||||
try:
|
||||
# 配置连接池限制,支持高并发
|
||||
@@ -43,12 +68,14 @@ class AIService:
|
||||
)
|
||||
|
||||
client_kwargs = {
|
||||
"api_key": settings.openai_api_key,
|
||||
"api_key": openai_key,
|
||||
"http_client": http_client
|
||||
}
|
||||
|
||||
if settings.openai_base_url:
|
||||
client_kwargs["base_url"] = settings.openai_base_url
|
||||
# 优先使用用户提供的base_url,否则使用全局配置
|
||||
base_url = api_base_url if api_provider == "openai" else app_settings.openai_base_url
|
||||
if base_url:
|
||||
client_kwargs["base_url"] = base_url
|
||||
|
||||
self.openai_client = AsyncOpenAI(**client_kwargs)
|
||||
logger.info("✅ OpenAI客户端初始化成功")
|
||||
@@ -62,7 +89,8 @@ class AIService:
|
||||
logger.warning("OpenAI API key未配置")
|
||||
|
||||
# 初始化Anthropic客户端
|
||||
if settings.anthropic_api_key:
|
||||
anthropic_key = api_key if api_provider == "anthropic" else app_settings.anthropic_api_key
|
||||
if anthropic_key:
|
||||
try:
|
||||
# 为Anthropic设置相同的超时和连接池配置
|
||||
limits = httpx.Limits(
|
||||
@@ -82,12 +110,14 @@ class AIService:
|
||||
)
|
||||
|
||||
client_kwargs = {
|
||||
"api_key": settings.anthropic_api_key,
|
||||
"api_key": anthropic_key,
|
||||
"http_client": http_client
|
||||
}
|
||||
|
||||
if settings.anthropic_base_url:
|
||||
client_kwargs["base_url"] = settings.anthropic_base_url
|
||||
# 优先使用用户提供的base_url,否则使用全局配置
|
||||
base_url = api_base_url if api_provider == "anthropic" else app_settings.anthropic_base_url
|
||||
if base_url:
|
||||
client_kwargs["base_url"] = base_url
|
||||
|
||||
self.anthropic_client = AsyncAnthropic(**client_kwargs)
|
||||
logger.info("✅ Anthropic客户端初始化成功")
|
||||
@@ -123,10 +153,10 @@ class AIService:
|
||||
Returns:
|
||||
生成的文本
|
||||
"""
|
||||
provider = provider or settings.default_ai_provider
|
||||
model = model or settings.default_model
|
||||
temperature = temperature or settings.default_temperature
|
||||
max_tokens = max_tokens or settings.default_max_tokens
|
||||
provider = provider or self.api_provider
|
||||
model = model or self.default_model
|
||||
temperature = temperature or self.default_temperature
|
||||
max_tokens = max_tokens or self.default_max_tokens
|
||||
|
||||
if provider == "openai":
|
||||
return await self._generate_openai(
|
||||
@@ -162,10 +192,10 @@ class AIService:
|
||||
Yields:
|
||||
生成的文本片段
|
||||
"""
|
||||
provider = provider or settings.default_ai_provider
|
||||
model = model or settings.default_model
|
||||
temperature = temperature or settings.default_temperature
|
||||
max_tokens = max_tokens or settings.default_max_tokens
|
||||
provider = provider or self.api_provider
|
||||
model = model or self.default_model
|
||||
temperature = temperature or self.default_temperature
|
||||
max_tokens = max_tokens or self.default_max_tokens
|
||||
|
||||
if provider == "openai":
|
||||
async for chunk in self._generate_openai_stream(
|
||||
@@ -359,5 +389,37 @@ class AIService:
|
||||
raise
|
||||
|
||||
|
||||
# 创建全局AI服务实例
|
||||
ai_service = AIService()
|
||||
# 创建全局AI服务实例(使用环境变量配置,用于向后兼容)
|
||||
ai_service = AIService()
|
||||
|
||||
|
||||
def create_user_ai_service(
|
||||
api_provider: str,
|
||||
api_key: str,
|
||||
api_base_url: str,
|
||||
model_name: str,
|
||||
temperature: float,
|
||||
max_tokens: int
|
||||
) -> AIService:
|
||||
"""
|
||||
根据用户设置创建AI服务实例
|
||||
|
||||
Args:
|
||||
api_provider: API提供商
|
||||
api_key: API密钥
|
||||
api_base_url: API基础URL
|
||||
model_name: 模型名称
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大tokens
|
||||
|
||||
Returns:
|
||||
AIService实例
|
||||
"""
|
||||
return AIService(
|
||||
api_provider=api_provider,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
default_model=model_name,
|
||||
default_temperature=temperature,
|
||||
default_max_tokens=max_tokens
|
||||
)
|
||||
Reference in New Issue
Block a user