fix:修复MuMuのAPI适配器错误
This commit is contained in:
@@ -23,7 +23,7 @@ from app.schemas.settings import (
|
|||||||
from app.user_manager import User
|
from app.user_manager import User
|
||||||
from app.logger import get_logger
|
from app.logger import get_logger
|
||||||
from app.config import settings as app_settings, PROJECT_ROOT
|
from app.config import settings as app_settings, PROJECT_ROOT
|
||||||
from app.services.ai_service import AIService, create_user_ai_service, create_user_ai_service_with_mcp
|
from app.services.ai_service import AIService, create_user_ai_service, create_user_ai_service_with_mcp, normalize_provider
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -315,6 +315,7 @@ async def get_available_models(
|
|||||||
模型列表
|
模型列表
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
provider = normalize_provider(provider)
|
||||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
if provider == "openai" or provider == "azure" or provider == "custom":
|
if provider == "openai" or provider == "azure" or provider == "custom":
|
||||||
# OpenAI 兼容接口获取模型列表
|
# OpenAI 兼容接口获取模型列表
|
||||||
@@ -436,7 +437,7 @@ async def check_function_calling_support(data: ApiTestRequest):
|
|||||||
"""
|
"""
|
||||||
api_key = data.api_key
|
api_key = data.api_key
|
||||||
api_base_url = data.api_base_url
|
api_base_url = data.api_base_url
|
||||||
provider = data.provider
|
provider = normalize_provider(data.provider)
|
||||||
llm_model = data.llm_model
|
llm_model = data.llm_model
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -652,7 +653,7 @@ async def test_api_connection(data: ApiTestRequest):
|
|||||||
"""
|
"""
|
||||||
api_key = data.api_key
|
api_key = data.api_key
|
||||||
api_base_url = data.api_base_url
|
api_base_url = data.api_base_url
|
||||||
provider = data.provider
|
provider = normalize_provider(data.provider)
|
||||||
llm_model = data.llm_model
|
llm_model = data.llm_model
|
||||||
# 使用前端传递的参数,如果未传递则使用默认值
|
# 使用前端传递的参数,如果未传递则使用默认值
|
||||||
temperature = data.temperature if data.temperature is not None else 0.7
|
temperature = data.temperature if data.temperature is not None else 0.7
|
||||||
@@ -897,7 +898,10 @@ async def create_preset(
|
|||||||
"description": data.description,
|
"description": data.description,
|
||||||
"is_active": False,
|
"is_active": False,
|
||||||
"created_at": datetime.now().isoformat(),
|
"created_at": datetime.now().isoformat(),
|
||||||
"config": data.config.model_dump()
|
"config": {
|
||||||
|
**data.config.model_dump(),
|
||||||
|
"api_provider": normalize_provider(data.config.api_provider)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
presets.append(new_preset)
|
presets.append(new_preset)
|
||||||
@@ -947,7 +951,10 @@ async def update_preset(
|
|||||||
if data.description is not None:
|
if data.description is not None:
|
||||||
target_preset['description'] = data.description
|
target_preset['description'] = data.description
|
||||||
if data.config is not None:
|
if data.config is not None:
|
||||||
target_preset['config'] = data.config.model_dump()
|
target_preset['config'] = {
|
||||||
|
**data.config.model_dump(),
|
||||||
|
'api_provider': normalize_provider(data.config.api_provider)
|
||||||
|
}
|
||||||
|
|
||||||
# 保存回preferences
|
# 保存回preferences
|
||||||
prefs['api_presets'] = api_presets
|
prefs['api_presets'] = api_presets
|
||||||
@@ -1033,7 +1040,7 @@ async def activate_preset(
|
|||||||
|
|
||||||
# 应用配置到Settings主字段
|
# 应用配置到Settings主字段
|
||||||
config = target_preset['config']
|
config = target_preset['config']
|
||||||
settings.api_provider = config['api_provider']
|
settings.api_provider = normalize_provider(config['api_provider'])
|
||||||
settings.api_key = config['api_key']
|
settings.api_key = config['api_key']
|
||||||
settings.api_base_url = config.get('api_base_url')
|
settings.api_base_url = config.get('api_base_url')
|
||||||
settings.llm_model = config['llm_model']
|
settings.llm_model = config['llm_model']
|
||||||
@@ -1116,7 +1123,7 @@ async def create_preset_from_current(
|
|||||||
|
|
||||||
# 从当前Settings主字段读取配置
|
# 从当前Settings主字段读取配置
|
||||||
current_config = APIKeyPresetConfig(
|
current_config = APIKeyPresetConfig(
|
||||||
api_provider=settings.api_provider,
|
api_provider=normalize_provider(settings.api_provider),
|
||||||
api_key=settings.api_key,
|
api_key=settings.api_key,
|
||||||
api_base_url=settings.api_base_url,
|
api_base_url=settings.api_base_url,
|
||||||
llm_model=settings.llm_model,
|
llm_model=settings.llm_model,
|
||||||
|
|||||||
@@ -26,6 +26,13 @@ cleanup_http_clients = cleanup_all_clients
|
|||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_provider(provider: Optional[str]) -> Optional[str]:
|
||||||
|
"""标准化 provider 名称,兼容渠道别名。"""
|
||||||
|
if provider == "mumu":
|
||||||
|
return "openai"
|
||||||
|
return provider
|
||||||
|
|
||||||
|
|
||||||
class AIService:
|
class AIService:
|
||||||
"""
|
"""
|
||||||
AI服务统一接口
|
AI服务统一接口
|
||||||
@@ -78,7 +85,7 @@ class AIService:
|
|||||||
db_session: Optional[Any] = None,
|
db_session: Optional[Any] = None,
|
||||||
enable_mcp: bool = True,
|
enable_mcp: bool = True,
|
||||||
):
|
):
|
||||||
self.api_provider = api_provider or app_settings.default_ai_provider
|
self.api_provider = normalize_provider(api_provider or app_settings.default_ai_provider)
|
||||||
self.default_model = default_model or app_settings.default_model
|
self.default_model = default_model or app_settings.default_model
|
||||||
self.default_temperature = default_temperature or app_settings.default_temperature
|
self.default_temperature = default_temperature or app_settings.default_temperature
|
||||||
self.default_max_tokens = default_max_tokens or app_settings.default_max_tokens
|
self.default_max_tokens = default_max_tokens or app_settings.default_max_tokens
|
||||||
@@ -97,21 +104,21 @@ class AIService:
|
|||||||
self._gemini_provider: Optional[GeminiProvider] = None
|
self._gemini_provider: Optional[GeminiProvider] = None
|
||||||
|
|
||||||
# 初始化 OpenAI
|
# 初始化 OpenAI
|
||||||
openai_key = api_key if api_provider == "openai" else app_settings.openai_api_key
|
openai_key = api_key if self.api_provider == "openai" else app_settings.openai_api_key
|
||||||
if openai_key:
|
if openai_key:
|
||||||
base_url = api_base_url if api_provider == "openai" else app_settings.openai_base_url
|
base_url = api_base_url if self.api_provider == "openai" else app_settings.openai_base_url
|
||||||
client = OpenAIClient(openai_key, base_url or "https://api.openai.com/v1", self.config)
|
client = OpenAIClient(openai_key, base_url or "https://api.openai.com/v1", self.config)
|
||||||
self._openai_provider = OpenAIProvider(client)
|
self._openai_provider = OpenAIProvider(client)
|
||||||
|
|
||||||
# 初始化 Anthropic
|
# 初始化 Anthropic
|
||||||
anthropic_key = api_key if api_provider == "anthropic" else app_settings.anthropic_api_key
|
anthropic_key = api_key if self.api_provider == "anthropic" else app_settings.anthropic_api_key
|
||||||
if anthropic_key:
|
if anthropic_key:
|
||||||
base_url = api_base_url if api_provider == "anthropic" else app_settings.anthropic_base_url
|
base_url = api_base_url if self.api_provider == "anthropic" else app_settings.anthropic_base_url
|
||||||
client = AnthropicClient(anthropic_key, base_url, self.config)
|
client = AnthropicClient(anthropic_key, base_url, self.config)
|
||||||
self._anthropic_provider = AnthropicProvider(client)
|
self._anthropic_provider = AnthropicProvider(client)
|
||||||
|
|
||||||
# 初始化 Gemini
|
# 初始化 Gemini
|
||||||
if api_provider == "gemini" and api_key:
|
if self.api_provider == "gemini" and api_key:
|
||||||
client = GeminiClient(api_key, api_base_url, self.config)
|
client = GeminiClient(api_key, api_base_url, self.config)
|
||||||
self._gemini_provider = GeminiProvider(client)
|
self._gemini_provider = GeminiProvider(client)
|
||||||
|
|
||||||
@@ -147,7 +154,7 @@ class AIService:
|
|||||||
|
|
||||||
def _get_provider(self, provider: Optional[str] = None) -> BaseAIProvider:
|
def _get_provider(self, provider: Optional[str] = None) -> BaseAIProvider:
|
||||||
"""获取对应的 Provider"""
|
"""获取对应的 Provider"""
|
||||||
p = provider or self.api_provider
|
p = normalize_provider(provider or self.api_provider)
|
||||||
if p == "openai" and self._openai_provider:
|
if p == "openai" and self._openai_provider:
|
||||||
return self._openai_provider
|
return self._openai_provider
|
||||||
if p == "anthropic" and self._anthropic_provider:
|
if p == "anthropic" and self._anthropic_provider:
|
||||||
|
|||||||
Reference in New Issue
Block a user