From 7e21049216795c2a397b75e4eced5ff9dd27f9bc Mon Sep 17 00:00:00 2001 From: xiamuceer Date: Tue, 17 Mar 2026 17:31:08 +0800 Subject: [PATCH] =?UTF-8?q?fix=EF=BC=9A=E4=BF=AE=E5=A4=8DMuMu=E3=81=AEAPI?= =?UTF-8?q?=E9=80=82=E9=85=8D=E5=99=A8=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/app/api/settings.py | 21 ++++++++++++++------- backend/app/services/ai_service.py | 21 ++++++++++++++------- 2 files changed, 28 insertions(+), 14 deletions(-) diff --git a/backend/app/api/settings.py b/backend/app/api/settings.py index 1fa629e..f9d66b5 100644 --- a/backend/app/api/settings.py +++ b/backend/app/api/settings.py @@ -23,7 +23,7 @@ from app.schemas.settings import ( from app.user_manager import User from app.logger import get_logger from app.config import settings as app_settings, PROJECT_ROOT -from app.services.ai_service import AIService, create_user_ai_service, create_user_ai_service_with_mcp +from app.services.ai_service import AIService, create_user_ai_service, create_user_ai_service_with_mcp, normalize_provider logger = get_logger(__name__) @@ -315,6 +315,7 @@ async def get_available_models( 模型列表 """ try: + provider = normalize_provider(provider) async with httpx.AsyncClient(timeout=10.0) as client: if provider == "openai" or provider == "azure" or provider == "custom": # OpenAI 兼容接口获取模型列表 @@ -436,7 +437,7 @@ async def check_function_calling_support(data: ApiTestRequest): """ api_key = data.api_key api_base_url = data.api_base_url - provider = data.provider + provider = normalize_provider(data.provider) llm_model = data.llm_model try: @@ -652,7 +653,7 @@ async def test_api_connection(data: ApiTestRequest): """ api_key = data.api_key api_base_url = data.api_base_url - provider = data.provider + provider = normalize_provider(data.provider) llm_model = data.llm_model # 使用前端传递的参数,如果未传递则使用默认值 temperature = data.temperature if data.temperature is not None else 0.7 @@ -897,7 +898,10 @@ async def create_preset( "description": data.description, "is_active": False, "created_at": datetime.now().isoformat(), - "config": data.config.model_dump() + "config": { + **data.config.model_dump(), + "api_provider": normalize_provider(data.config.api_provider) + } } presets.append(new_preset) @@ -947,7 +951,10 @@ async def update_preset( if data.description is not None: target_preset['description'] = data.description if data.config is not None: - target_preset['config'] = data.config.model_dump() + target_preset['config'] = { + **data.config.model_dump(), + 'api_provider': normalize_provider(data.config.api_provider) + } # 保存回preferences prefs['api_presets'] = api_presets @@ -1033,7 +1040,7 @@ async def activate_preset( # 应用配置到Settings主字段 config = target_preset['config'] - settings.api_provider = config['api_provider'] + settings.api_provider = normalize_provider(config['api_provider']) settings.api_key = config['api_key'] settings.api_base_url = config.get('api_base_url') settings.llm_model = config['llm_model'] @@ -1116,7 +1123,7 @@ async def create_preset_from_current( # 从当前Settings主字段读取配置 current_config = APIKeyPresetConfig( - api_provider=settings.api_provider, + api_provider=normalize_provider(settings.api_provider), api_key=settings.api_key, api_base_url=settings.api_base_url, llm_model=settings.llm_model, diff --git a/backend/app/services/ai_service.py b/backend/app/services/ai_service.py index bd31b86..f3eb706 100644 --- a/backend/app/services/ai_service.py +++ b/backend/app/services/ai_service.py @@ -26,6 +26,13 @@ cleanup_http_clients = cleanup_all_clients logger = get_logger(__name__) +def normalize_provider(provider: Optional[str]) -> Optional[str]: + """标准化 provider 名称,兼容渠道别名。""" + if provider == "mumu": + return "openai" + return provider + + class AIService: """ AI服务统一接口 @@ -78,7 +85,7 @@ class AIService: db_session: Optional[Any] = None, enable_mcp: bool = True, ): - self.api_provider = api_provider or app_settings.default_ai_provider + self.api_provider = normalize_provider(api_provider or app_settings.default_ai_provider) self.default_model = default_model or app_settings.default_model self.default_temperature = default_temperature or app_settings.default_temperature self.default_max_tokens = default_max_tokens or app_settings.default_max_tokens @@ -97,21 +104,21 @@ class AIService: self._gemini_provider: Optional[GeminiProvider] = None # 初始化 OpenAI - openai_key = api_key if api_provider == "openai" else app_settings.openai_api_key + openai_key = api_key if self.api_provider == "openai" else app_settings.openai_api_key if openai_key: - base_url = api_base_url if api_provider == "openai" else app_settings.openai_base_url + base_url = api_base_url if self.api_provider == "openai" else app_settings.openai_base_url client = OpenAIClient(openai_key, base_url or "https://api.openai.com/v1", self.config) self._openai_provider = OpenAIProvider(client) # 初始化 Anthropic - anthropic_key = api_key if api_provider == "anthropic" else app_settings.anthropic_api_key + anthropic_key = api_key if self.api_provider == "anthropic" else app_settings.anthropic_api_key if anthropic_key: - base_url = api_base_url if api_provider == "anthropic" else app_settings.anthropic_base_url + base_url = api_base_url if self.api_provider == "anthropic" else app_settings.anthropic_base_url client = AnthropicClient(anthropic_key, base_url, self.config) self._anthropic_provider = AnthropicProvider(client) # 初始化 Gemini - if api_provider == "gemini" and api_key: + if self.api_provider == "gemini" and api_key: client = GeminiClient(api_key, api_base_url, self.config) self._gemini_provider = GeminiProvider(client) @@ -147,7 +154,7 @@ class AIService: def _get_provider(self, provider: Optional[str] = None) -> BaseAIProvider: """获取对应的 Provider""" - p = provider or self.api_provider + p = normalize_provider(provider or self.api_provider) if p == "openai" and self._openai_provider: return self._openai_provider if p == "anthropic" and self._anthropic_provider: