From 3aefdd433d32ff21b9f333ffec708834cf0ee8d3 Mon Sep 17 00:00:00 2001 From: xiamuceer Date: Thu, 30 Oct 2025 16:53:50 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=AF=E6=8C=81=E8=87=AA=E5=AE=9A=E4=B9=89AP?= =?UTF-8?q?I=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/app/api/chapters.py | 13 +- backend/app/api/characters.py | 8 +- backend/app/api/outlines.py | 20 +- backend/app/api/polish.py | 11 +- backend/app/api/settings.py | 299 +++++++++++++++ backend/app/api/wizard_stream.py | 94 +++-- backend/app/main.py | 3 +- backend/app/models/settings.py | 9 +- backend/app/schemas/settings.py | 37 ++ backend/app/services/ai_service.py | 106 ++++-- backend/app/services/prompt_service.py | 60 ++- frontend/src/App.tsx | 2 + frontend/src/pages/ProjectList.tsx | 29 +- frontend/src/pages/Settings.tsx | 507 +++++++++++++++++++++++++ frontend/src/services/api.ts | 17 + frontend/src/types/index.ts | 25 ++ 16 files changed, 1143 insertions(+), 97 deletions(-) create mode 100644 backend/app/api/settings.py create mode 100644 backend/app/schemas/settings.py create mode 100644 frontend/src/pages/Settings.tsx diff --git a/backend/app/api/chapters.py b/backend/app/api/chapters.py index b4ae9b4..c625ed3 100644 --- a/backend/app/api/chapters.py +++ b/backend/app/api/chapters.py @@ -18,9 +18,10 @@ from app.schemas.chapter import ( ChapterResponse, ChapterListResponse ) -from app.services.ai_service import ai_service +from app.services.ai_service import AIService from app.services.prompt_service import prompt_service from app.logger import get_logger +from app.api.settings import get_user_ai_service router = APIRouter(prefix="/chapters", tags=["章节管理"]) logger = get_logger(__name__) @@ -247,7 +248,8 @@ async def check_can_generate( @router.post("/{chapter_id}/generate", summary="AI创作章节内容") async def generate_chapter_content( chapter_id: str, - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), + user_ai_service: AIService = Depends(get_user_ai_service) ): """ 根据大纲、前置章节内容和项目信息AI创作章节完整内容 @@ -372,7 +374,7 @@ async def generate_chapter_content( logger.info(f"开始AI创作章节 {chapter_id}") # 调用AI生成 - ai_content = await ai_service.generate_text( + ai_content = await user_ai_service.generate_text( prompt=prompt ) @@ -410,7 +412,8 @@ async def generate_chapter_content( @router.post("/{chapter_id}/generate-stream", summary="AI创作章节内容(流式)") async def generate_chapter_content_stream( chapter_id: str, - request: Request + request: Request, + user_ai_service: AIService = Depends(get_user_ai_service) ): """ 根据大纲、前置章节内容和项目信息AI创作章节完整内容(流式返回) @@ -569,7 +572,7 @@ async def generate_chapter_content_stream( # 流式生成内容 full_content = "" - async for chunk in ai_service.generate_text_stream(prompt=prompt): + async for chunk in user_ai_service.generate_text_stream(prompt=prompt): full_content += chunk yield f"data: {json.dumps({'type': 'content', 'content': chunk}, ensure_ascii=False)}\n\n" await asyncio.sleep(0) # 让出控制权 diff --git a/backend/app/api/characters.py b/backend/app/api/characters.py index b689ce7..13066b5 100644 --- a/backend/app/api/characters.py +++ b/backend/app/api/characters.py @@ -15,9 +15,10 @@ from app.schemas.character import ( CharacterListResponse, CharacterGenerateRequest ) -from app.services.ai_service import ai_service +from app.services.ai_service import AIService from app.services.prompt_service import prompt_service from app.logger import get_logger +from app.api.settings import get_user_ai_service router = APIRouter(prefix="/characters", tags=["角色管理"]) logger = get_logger(__name__) @@ -134,7 +135,8 @@ async def delete_character( @router.post("/generate", response_model=CharacterResponse, summary="AI生成角色") async def generate_character( request: CharacterGenerateRequest, - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), + user_ai_service: AIService = Depends(get_user_ai_service) ): """ 使用AI生成角色卡 @@ -216,7 +218,7 @@ async def generate_character( logger.info(f" - Prompt长度:{len(prompt)} 字符") try: - ai_response = await ai_service.generate_text( + ai_response = await user_ai_service.generate_text( prompt=prompt, provider=request.provider, model=request.model diff --git a/backend/app/api/outlines.py b/backend/app/api/outlines.py index b5493fa..d246977 100644 --- a/backend/app/api/outlines.py +++ b/backend/app/api/outlines.py @@ -19,9 +19,10 @@ from app.schemas.outline import ( OutlineGenerateRequest, OutlineReorderRequest ) -from app.services.ai_service import ai_service +from app.services.ai_service import AIService from app.services.prompt_service import prompt_service from app.logger import get_logger +from app.api.settings import get_user_ai_service router = APIRouter(prefix="/outlines", tags=["大纲管理"]) logger = get_logger(__name__) @@ -326,7 +327,8 @@ async def reorder_outlines( @router.post("/generate", response_model=OutlineListResponse, summary="AI生成/续写大纲") async def generate_outline( request: OutlineGenerateRequest, - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), + user_ai_service: AIService = Depends(get_user_ai_service) ): """ 使用AI生成或续写小说大纲 - 智能模式 @@ -363,7 +365,7 @@ async def generate_outline( # 模式:全新生成 if actual_mode == "new": return await _generate_new_outline( - request, project, db + request, project, db, user_ai_service ) # 模式:续写 @@ -375,7 +377,7 @@ async def generate_outline( ) return await _continue_outline( - request, project, existing_outlines, db + request, project, existing_outlines, db, user_ai_service ) else: @@ -394,7 +396,8 @@ async def generate_outline( async def _generate_new_outline( request: OutlineGenerateRequest, project: Project, - db: AsyncSession + db: AsyncSession, + user_ai_service: AIService ) -> OutlineListResponse: """全新生成大纲""" logger.info(f"全新生成大纲 - 项目: {project.id}, keep_existing: {request.keep_existing}") @@ -427,7 +430,7 @@ async def _generate_new_outline( ) # 调用AI - ai_response = await ai_service.generate_text( + ai_response = await user_ai_service.generate_text( prompt=prompt, provider=request.provider, model=request.model @@ -473,7 +476,8 @@ async def _continue_outline( request: OutlineGenerateRequest, project: Project, existing_outlines: List[Outline], - db: AsyncSession + db: AsyncSession, + user_ai_service: AIService ) -> OutlineListResponse: """续写大纲""" logger.info(f"续写大纲 - 项目: {project.id}, 已有: {len(existing_outlines)} 章") @@ -536,7 +540,7 @@ async def _continue_outline( ) # 调用AI - ai_response = await ai_service.generate_text( + ai_response = await user_ai_service.generate_text( prompt=prompt, provider=request.provider, model=request.model diff --git a/backend/app/api/polish.py b/backend/app/api/polish.py index d07839e..35767ff 100644 --- a/backend/app/api/polish.py +++ b/backend/app/api/polish.py @@ -5,9 +5,10 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.database import get_db from app.models.generation_history import GenerationHistory from app.schemas.polish import PolishRequest, PolishResponse -from app.services.ai_service import ai_service +from app.services.ai_service import AIService from app.services.prompt_service import prompt_service from app.logger import get_logger +from app.api.settings import get_user_ai_service router = APIRouter(prefix="/polish", tags=["AI去味"]) logger = get_logger(__name__) @@ -16,7 +17,8 @@ logger = get_logger(__name__) @router.post("", response_model=PolishResponse, summary="AI去味") async def polish_text( request: PolishRequest, - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), + user_ai_service: AIService = Depends(get_user_ai_service) ): """ AI去味 - 将AI生成的文本改写得更像人类作家的手笔 @@ -83,7 +85,8 @@ async def polish_batch( project_id: int = None, provider: str = None, model: str = None, - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), + user_ai_service: AIService = Depends(get_user_ai_service) ): """ 批量处理多个文本的AI去味 @@ -98,7 +101,7 @@ async def polish_batch( prompt = prompt_service.get_denoising_prompt(original_text=text) - polished_text = await ai_service.generate_text( + polished_text = await user_ai_service.generate_text( prompt=prompt, provider=provider, model=model diff --git a/backend/app/api/settings.py b/backend/app/api/settings.py new file mode 100644 index 0000000..434466a --- /dev/null +++ b/backend/app/api/settings.py @@ -0,0 +1,299 @@ +""" +设置管理 API +""" +from fastapi import APIRouter, HTTPException, Request, Depends +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select +from typing import Dict, Any, List +from pathlib import Path +import httpx + +from app.database import get_db +from app.models.settings import Settings +from app.schemas.settings import SettingsCreate, SettingsUpdate, SettingsResponse +from app.user_manager import User +from app.logger import get_logger +from app.config import settings as app_settings, PROJECT_ROOT +from app.services.ai_service import AIService, create_user_ai_service + +logger = get_logger(__name__) + +router = APIRouter(prefix="/settings", tags=["设置管理"]) + + +def read_env_defaults() -> Dict[str, Any]: + """从.env文件读取默认配置(仅读取,不修改)""" + return { + "api_provider": app_settings.default_ai_provider, + "api_key": app_settings.openai_api_key or app_settings.anthropic_api_key or "", + "api_base_url": app_settings.openai_base_url or app_settings.anthropic_base_url or "", + "model_name": app_settings.default_model, + "temperature": app_settings.default_temperature, + "max_tokens": app_settings.default_max_tokens, + } + + +def require_login(request: Request): + """依赖:要求用户已登录""" + if not hasattr(request.state, "user") or not request.state.user: + raise HTTPException(status_code=401, detail="需要登录") + return request.state.user + + +async def get_user_ai_service( + user: User = Depends(require_login), + db: AsyncSession = Depends(get_db) +) -> AIService: + """ + 依赖:获取当前用户的AI服务实例 + 从数据库读取用户设置并创建对应的AI服务 + """ + result = await db.execute( + select(Settings).where(Settings.user_id == user.user_id) + ) + settings = result.scalar_one_or_none() + + if not settings: + # 如果用户没有设置,从.env读取并保存 + env_defaults = read_env_defaults() + settings = Settings( + user_id=user.user_id, + **env_defaults + ) + db.add(settings) + await db.commit() + await db.refresh(settings) + logger.info(f"用户 {user.user_id} 首次使用AI服务,已从.env同步设置到数据库") + + # 使用用户设置创建AI服务实例 + return create_user_ai_service( + api_provider=settings.api_provider, + api_key=settings.api_key, + api_base_url=settings.api_base_url or "", + model_name=settings.model_name, + temperature=settings.temperature, + max_tokens=settings.max_tokens + ) + + +@router.get("", response_model=SettingsResponse) +async def get_settings( + user: User = Depends(require_login), + db: AsyncSession = Depends(get_db) +): + """ + 获取当前用户的设置 + 如果用户没有保存过设置,自动从.env创建并保存到数据库 + """ + result = await db.execute( + select(Settings).where(Settings.user_id == user.user_id) + ) + settings = result.scalar_one_or_none() + + if not settings: + # 如果用户没有保存过设置,从.env读取默认配置并保存到数据库 + env_defaults = read_env_defaults() + logger.info(f"用户 {user.user_id} 首次获取设置,自动从.env同步到数据库") + + # 创建新设置并保存到数据库 + settings = Settings( + user_id=user.user_id, + **env_defaults + ) + db.add(settings) + await db.commit() + await db.refresh(settings) + logger.info(f"用户 {user.user_id} 的设置已从.env同步到数据库") + + logger.info(f"用户 {user.user_id} 获取已保存的设置") + return settings + + +@router.post("", response_model=SettingsResponse) +async def save_settings( + data: SettingsCreate, + user: User = Depends(require_login), + db: AsyncSession = Depends(get_db) +): + """ + 创建或更新当前用户的设置(Upsert) + 如果设置已存在则更新,否则创建新设置 + 仅保存到数据库 + """ + # 查找现有设置 + result = await db.execute( + select(Settings).where(Settings.user_id == user.user_id) + ) + settings = result.scalar_one_or_none() + + # 准备数据 + settings_dict = data.model_dump(exclude_unset=True) + + if settings: + # 更新现有设置 + for key, value in settings_dict.items(): + setattr(settings, key, value) + + await db.commit() + await db.refresh(settings) + logger.info(f"用户 {user.user_id} 更新设置") + else: + # 创建新设置 + settings = Settings( + user_id=user.user_id, + **settings_dict + ) + db.add(settings) + await db.commit() + await db.refresh(settings) + logger.info(f"用户 {user.user_id} 创建设置") + + return settings + + +@router.put("", response_model=SettingsResponse) +async def update_settings( + data: SettingsUpdate, + user: User = Depends(require_login), + db: AsyncSession = Depends(get_db) +): + """ + 更新当前用户的设置 + 仅保存到数据库 + """ + result = await db.execute( + select(Settings).where(Settings.user_id == user.user_id) + ) + settings = result.scalar_one_or_none() + + if not settings: + raise HTTPException(status_code=404, detail="设置不存在,请先创建设置") + + # 更新设置 + update_data = data.model_dump(exclude_unset=True) + for key, value in update_data.items(): + setattr(settings, key, value) + + await db.commit() + await db.refresh(settings) + logger.info(f"用户 {user.user_id} 更新设置") + + return settings + + +@router.delete("") +async def delete_settings( + user: User = Depends(require_login), + db: AsyncSession = Depends(get_db) +): + """ + 删除当前用户的设置 + """ + result = await db.execute( + select(Settings).where(Settings.user_id == user.user_id) + ) + settings = result.scalar_one_or_none() + + if not settings: + raise HTTPException(status_code=404, detail="设置不存在") + + await db.delete(settings) + await db.commit() + logger.info(f"用户 {user.user_id} 删除设置") + + return {"message": "设置已删除", "user_id": user.user_id} + + +@router.get("/models") +async def get_available_models( + api_key: str, + api_base_url: str, + provider: str = "openai" +): + """ + 从配置的 API 获取可用的模型列表 + + Args: + api_key: API 密钥 + api_base_url: API 基础 URL + provider: API 提供商 (openai, anthropic, azure, custom) + + Returns: + 模型列表 + """ + try: + async with httpx.AsyncClient(timeout=10.0) as client: + if provider == "openai" or provider == "azure" or provider == "custom": + # OpenAI 兼容接口获取模型列表 + url = f"{api_base_url.rstrip('/')}/models" + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + } + + logger.info(f"正在从 {url} 获取模型列表") + response = await client.get(url, headers=headers) + response.raise_for_status() + + data = response.json() + models = [] + + if "data" in data and isinstance(data["data"], list): + for model in data["data"]: + model_id = model.get("id", "") + # 过滤出常用的文本生成模型 + if any(keyword in model_id.lower() for keyword in [ + "gpt", "gemini", "claude", "llama", "mistral", "qwen", "deepseek" + ]): + models.append({ + "value": model_id, + "label": model_id, + "description": model.get("description", "") or f"Created: {model.get('created', 'N/A')}" + }) + + if not models: + raise HTTPException( + status_code=404, + detail="未能从 API 获取到可用的模型列表" + ) + + logger.info(f"成功获取 {len(models)} 个模型") + return { + "provider": provider, + "models": models, + "count": len(models) + } + + elif provider == "anthropic": + # Anthropic 没有公开的模型列表API + raise HTTPException( + status_code=400, + detail="Anthropic 不支持自动获取模型列表,请手动输入模型名称" + ) + + else: + raise HTTPException( + status_code=400, + detail=f"不支持的提供商: {provider}" + ) + + except httpx.HTTPStatusError as e: + logger.error(f"获取模型列表失败 (HTTP {e.response.status_code}): {e.response.text}") + raise HTTPException( + status_code=400, + detail=f"无法从 API 获取模型列表 (HTTP {e.response.status_code})" + ) + except httpx.RequestError as e: + logger.error(f"请求模型列表失败: {str(e)}") + raise HTTPException( + status_code=400, + detail=f"无法连接到 API: {str(e)}" + ) + except HTTPException: + raise + except Exception as e: + logger.error(f"获取模型列表时发生错误: {str(e)}") + raise HTTPException( + status_code=500, + detail=f"获取模型列表失败: {str(e)}" + ) \ No newline at end of file diff --git a/backend/app/api/wizard_stream.py b/backend/app/api/wizard_stream.py index ed05d0f..0d304aa 100644 --- a/backend/app/api/wizard_stream.py +++ b/backend/app/api/wizard_stream.py @@ -4,6 +4,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select from typing import Dict, Any, AsyncGenerator import json +import re from app.database import get_db from app.models.project import Project @@ -11,10 +12,11 @@ from app.models.character import Character from app.models.outline import Outline from app.models.chapter import Chapter from app.models.relationship import CharacterRelationship, Organization, OrganizationMember, RelationshipType -from app.services.ai_service import ai_service +from app.services.ai_service import AIService from app.services.prompt_service import prompt_service from app.logger import get_logger from app.utils.sse_response import SSEResponse, create_sse_response +from app.api.settings import get_user_ai_service router = APIRouter(prefix="/wizard-stream", tags=["项目创建向导(流式)"]) logger = get_logger(__name__) @@ -22,7 +24,8 @@ logger = get_logger(__name__) async def world_building_generator( data: Dict[str, Any], - db: AsyncSession + db: AsyncSession, + user_ai_service: AIService ) -> AsyncGenerator[str, None]: """世界构建流式生成器""" # 标记数据库会话是否已提交 @@ -61,7 +64,7 @@ async def world_building_generator( accumulated_text = "" chunk_count = 0 - async for chunk in ai_service.generate_text_stream( + async for chunk in user_ai_service.generate_text_stream( prompt=prompt, provider=provider, model=model @@ -87,24 +90,26 @@ async def world_building_generator( world_data = {} try: cleaned_text = accumulated_text.strip() + + # 移除markdown代码块标记 if cleaned_text.startswith('```json'): - cleaned_text = cleaned_text[7:] - if cleaned_text.startswith('```'): - cleaned_text = cleaned_text[3:] + cleaned_text = cleaned_text[7:].lstrip('\n\r') + elif cleaned_text.startswith('```'): + cleaned_text = cleaned_text[3:].lstrip('\n\r') if cleaned_text.endswith('```'): - cleaned_text = cleaned_text[:-3] + cleaned_text = cleaned_text[:-3].rstrip('\n\r') cleaned_text = cleaned_text.strip() world_data = json.loads(cleaned_text) + except json.JSONDecodeError as e: - logger.error(f"AI返回非JSON格式: {e}") + logger.error(f"世界构建JSON解析失败: {e}") world_data = { - "time_period": accumulated_text[:300] if len(accumulated_text) > 300 else accumulated_text, + "time_period": "AI返回格式错误,请重试", "location": "AI返回格式错误,请重试", "atmosphere": "AI返回格式错误,请重试", "rules": "AI返回格式错误,请重试" } - # 保存到数据库 yield await SSEResponse.send_progress("保存到数据库...", 90) @@ -160,18 +165,20 @@ async def world_building_generator( @router.post("/world-building", summary="流式生成世界构建") async def generate_world_building_stream( data: Dict[str, Any], - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), + user_ai_service: AIService = Depends(get_user_ai_service) ): """ 使用SSE流式生成世界构建,避免超时 前端使用EventSource接收实时进度和结果 """ - return create_sse_response(world_building_generator(data, db)) + return create_sse_response(world_building_generator(data, db, user_ai_service)) async def characters_generator( data: Dict[str, Any], - db: AsyncSession + db: AsyncSession, + user_ai_service: AIService ) -> AsyncGenerator[str, None]: """角色批量生成流式生成器 - 优化版:分批+重试""" db_committed = False @@ -270,7 +277,7 @@ async def characters_generator( # 流式生成 accumulated_text = "" - async for chunk in ai_service.generate_text_stream( + async for chunk in user_ai_service.generate_text_stream( prompt=prompt, provider=provider, model=model @@ -280,12 +287,13 @@ async def characters_generator( # 解析批次结果 cleaned_text = accumulated_text.strip() + # 移除markdown代码块标记 if cleaned_text.startswith('```json'): - cleaned_text = cleaned_text[7:] - if cleaned_text.startswith('```'): - cleaned_text = cleaned_text[3:] + cleaned_text = cleaned_text[7:].lstrip('\n\r') + elif cleaned_text.startswith('```'): + cleaned_text = cleaned_text[3:].lstrip('\n\r') if cleaned_text.endswith('```'): - cleaned_text = cleaned_text[:-3] + cleaned_text = cleaned_text[:-3].rstrip('\n\r') cleaned_text = cleaned_text.strip() characters_data = json.loads(cleaned_text) @@ -684,17 +692,19 @@ async def characters_generator( @router.post("/characters", summary="流式批量生成角色") async def generate_characters_stream( data: Dict[str, Any], - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), + user_ai_service: AIService = Depends(get_user_ai_service) ): """ 使用SSE流式批量生成角色,避免超时 """ - return create_sse_response(characters_generator(data, db)) + return create_sse_response(characters_generator(data, db, user_ai_service)) async def outline_generator( data: Dict[str, Any], - db: AsyncSession + db: AsyncSession, + user_ai_service: AIService ) -> AsyncGenerator[str, None]: """大纲生成流式生成器 - 向导固定生成前8章作为开局""" db_committed = False @@ -778,6 +788,7 @@ async def outline_generator( batch_requirements += "2. 建立主线冲突和故事钩子\n" batch_requirements += "3. 展开初期情节,为后续发展埋下伏笔\n" batch_requirements += "4. 不要试图完结故事,这只是开始部分\n" + batch_requirements += "5. 不要在JSON字符串值中使用中文引号(""''),请使用【】或《》标记\n" batch_prompt = prompt_service.get_complete_outline_prompt( title=project.title, @@ -796,7 +807,7 @@ async def outline_generator( # 流式生成 accumulated_text = "" - async for chunk in ai_service.generate_text_stream( + async for chunk in user_ai_service.generate_text_stream( prompt=batch_prompt, provider=provider, model=model @@ -806,12 +817,14 @@ async def outline_generator( # 解析结果 cleaned_text = accumulated_text.strip() + + # 移除markdown代码块标记 if cleaned_text.startswith('```json'): - cleaned_text = cleaned_text[7:] - if cleaned_text.startswith('```'): - cleaned_text = cleaned_text[3:] + cleaned_text = cleaned_text[7:].lstrip('\n\r') + elif cleaned_text.startswith('```'): + cleaned_text = cleaned_text[3:].lstrip('\n\r') if cleaned_text.endswith('```'): - cleaned_text = cleaned_text[:-3] + cleaned_text = cleaned_text[:-3].rstrip('\n\r') cleaned_text = cleaned_text.strip() batch_outline_data = json.loads(cleaned_text) @@ -839,7 +852,7 @@ async def outline_generator( logger.info(f"批次{batch_idx+1}成功生成{len(batch_outline_data)}章大纲") except json.JSONDecodeError as e: - logger.error(f"批次{batch_idx+1}解析失败(尝试{retry_count+1}/{MAX_RETRIES}): {e}") + logger.error(f"大纲生成批次{batch_idx+1} JSON解析失败(尝试{retry_count+1}/{MAX_RETRIES}): {e}") retry_count += 1 if retry_count < MAX_RETRIES: yield await SSEResponse.send_progress( @@ -945,12 +958,13 @@ async def outline_generator( @router.post("/outline", summary="流式生成完整大纲") async def generate_outline_stream( data: Dict[str, Any], - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), + user_ai_service: AIService = Depends(get_user_ai_service) ): """ 使用SSE流式生成完整大纲,避免超时 """ - return create_sse_response(outline_generator(data, db)) + return create_sse_response(outline_generator(data, db, user_ai_service)) async def update_world_building_generator( @@ -1037,7 +1051,8 @@ async def update_world_building_stream( async def regenerate_world_building_generator( project_id: str, data: Dict[str, Any], - db: AsyncSession + db: AsyncSession, + user_ai_service: AIService ) -> AsyncGenerator[str, None]: """重新生成世界观流式生成器""" db_committed = False @@ -1070,7 +1085,7 @@ async def regenerate_world_building_generator( accumulated_text = "" chunk_count = 0 - async for chunk in ai_service.generate_text_stream( + async for chunk in user_ai_service.generate_text_stream( prompt=prompt, provider=provider, model=model @@ -1096,19 +1111,21 @@ async def regenerate_world_building_generator( world_data = {} try: cleaned_text = accumulated_text.strip() + # 移除markdown代码块标记 if cleaned_text.startswith('```json'): - cleaned_text = cleaned_text[7:] - if cleaned_text.startswith('```'): - cleaned_text = cleaned_text[3:] + cleaned_text = cleaned_text[7:].lstrip('\n\r') + elif cleaned_text.startswith('```'): + cleaned_text = cleaned_text[3:].lstrip('\n\r') if cleaned_text.endswith('```'): - cleaned_text = cleaned_text[:-3] + cleaned_text = cleaned_text[:-3].rstrip('\n\r') cleaned_text = cleaned_text.strip() world_data = json.loads(cleaned_text) except json.JSONDecodeError as e: logger.error(f"AI返回非JSON格式: {e}") + logger.info(world_data) world_data = { - "time_period": accumulated_text[:300] if len(accumulated_text) > 300 else accumulated_text, + "time_period": "AI返回格式错误,请重试", "location": "AI返回格式错误,请重试", "atmosphere": "AI返回格式错误,请重试", "rules": "AI返回格式错误,请重试" @@ -1155,7 +1172,8 @@ async def regenerate_world_building_generator( async def regenerate_world_building_stream( project_id: str, data: Dict[str, Any], - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), + user_ai_service: AIService = Depends(get_user_ai_service) ): """ 使用SSE流式重新生成项目的世界观 @@ -1165,7 +1183,7 @@ async def regenerate_world_building_stream( "model": "模型名称(可选)" } """ - return create_sse_response(regenerate_world_building_generator(project_id, data, db)) + return create_sse_response(regenerate_world_building_generator(project_id, data, db, user_ai_service)) async def cleanup_wizard_data_generator( diff --git a/backend/app/main.py b/backend/app/main.py index 0152612..fe65bad 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -114,11 +114,12 @@ async def db_session_stats(): from app.api import ( projects, outlines, characters, chapters, wizard_stream, relationships, organizations, - auth, users + auth, users, settings ) app.include_router(auth.router, prefix="/api") app.include_router(users.router, prefix="/api") +app.include_router(settings.router, prefix="/api") app.include_router(projects.router, prefix="/api") app.include_router(wizard_stream.router, prefix="/api") diff --git a/backend/app/models/settings.py b/backend/app/models/settings.py index c7aa33a..7aaa646 100644 --- a/backend/app/models/settings.py +++ b/backend/app/models/settings.py @@ -1,5 +1,5 @@ """设置数据模型""" -from sqlalchemy import Column, String, Text, Float, Integer, DateTime +from sqlalchemy import Column, String, Text, Float, Integer, DateTime, Index from sqlalchemy.sql import func from app.database import Base import uuid @@ -10,6 +10,7 @@ class Settings(Base): __tablename__ = "settings" id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + user_id = Column(String(50), nullable=False, unique=True, index=True, comment="用户ID") api_provider = Column(String(50), default="openai", comment="API提供商") api_key = Column(String(500), comment="API密钥") api_base_url = Column(String(500), comment="自定义API地址") @@ -20,5 +21,9 @@ class Settings(Base): created_at = Column(DateTime, server_default=func.now(), comment="创建时间") updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间") + __table_args__ = ( + Index('idx_user_id', 'user_id'), + ) + def __repr__(self): - return f"" \ No newline at end of file + return f"" \ No newline at end of file diff --git a/backend/app/schemas/settings.py b/backend/app/schemas/settings.py new file mode 100644 index 0000000..5e87f8f --- /dev/null +++ b/backend/app/schemas/settings.py @@ -0,0 +1,37 @@ +"""设置相关的Pydantic模型""" +from pydantic import BaseModel, Field, ConfigDict +from typing import Optional +from datetime import datetime + + +class SettingsBase(BaseModel): + """设置基础模型""" + model_config = ConfigDict(protected_namespaces=()) + + api_provider: Optional[str] = Field(default="openai", description="API提供商") + api_key: Optional[str] = Field(default=None, description="API密钥") + api_base_url: Optional[str] = Field(default=None, description="自定义API地址") + model_name: Optional[str] = Field(default="gpt-4", description="模型名称") + temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0, description="温度参数") + max_tokens: Optional[int] = Field(default=2000, ge=1, le=32000, description="最大token数") + preferences: Optional[str] = Field(default=None, description="其他偏好设置(JSON)") + + +class SettingsCreate(SettingsBase): + """创建设置请求模型""" + pass + + +class SettingsUpdate(SettingsBase): + """更新设置请求模型""" + pass + + +class SettingsResponse(SettingsBase): + """设置响应模型""" + model_config = ConfigDict(from_attributes=True, protected_namespaces=()) + + id: str + user_id: str + created_at: datetime + updated_at: datetime \ No newline at end of file diff --git a/backend/app/services/ai_service.py b/backend/app/services/ai_service.py index f2de62e..5cfe16c 100644 --- a/backend/app/services/ai_service.py +++ b/backend/app/services/ai_service.py @@ -2,7 +2,7 @@ from typing import Optional, AsyncGenerator, List, Dict, Any from openai import AsyncOpenAI from anthropic import AsyncAnthropic -from app.config import settings +from app.config import settings as app_settings from app.logger import get_logger import httpx @@ -10,12 +10,37 @@ logger = get_logger(__name__) class AIService: - """AI服务统一接口""" + """AI服务统一接口 - 支持从用户设置或全局配置初始化""" - def __init__(self): - """初始化AI客户端(优化并发性能)""" + def __init__( + self, + api_provider: Optional[str] = None, + api_key: Optional[str] = None, + api_base_url: Optional[str] = None, + default_model: Optional[str] = None, + default_temperature: Optional[float] = None, + default_max_tokens: Optional[int] = None + ): + """ + 初始化AI客户端(优化并发性能) + + Args: + api_provider: API提供商 (openai/anthropic),为None时使用全局配置 + api_key: API密钥,为None时使用全局配置 + api_base_url: API基础URL,为None时使用全局配置 + default_model: 默认模型,为None时使用全局配置 + default_temperature: 默认温度,为None时使用全局配置 + default_max_tokens: 默认最大tokens,为None时使用全局配置 + """ + # 保存用户设置或使用全局配置 + self.api_provider = api_provider or app_settings.default_ai_provider + self.default_model = default_model or app_settings.default_model + self.default_temperature = default_temperature or app_settings.default_temperature + self.default_max_tokens = default_max_tokens or app_settings.default_max_tokens + # 初始化OpenAI客户端 - if settings.openai_api_key: + openai_key = api_key if api_provider == "openai" else app_settings.openai_api_key + if openai_key: # 创建自定义的httpx客户端来避免proxies参数问题 try: # 配置连接池限制,支持高并发 @@ -43,12 +68,14 @@ class AIService: ) client_kwargs = { - "api_key": settings.openai_api_key, + "api_key": openai_key, "http_client": http_client } - if settings.openai_base_url: - client_kwargs["base_url"] = settings.openai_base_url + # 优先使用用户提供的base_url,否则使用全局配置 + base_url = api_base_url if api_provider == "openai" else app_settings.openai_base_url + if base_url: + client_kwargs["base_url"] = base_url self.openai_client = AsyncOpenAI(**client_kwargs) logger.info("✅ OpenAI客户端初始化成功") @@ -62,7 +89,8 @@ class AIService: logger.warning("OpenAI API key未配置") # 初始化Anthropic客户端 - if settings.anthropic_api_key: + anthropic_key = api_key if api_provider == "anthropic" else app_settings.anthropic_api_key + if anthropic_key: try: # 为Anthropic设置相同的超时和连接池配置 limits = httpx.Limits( @@ -82,12 +110,14 @@ class AIService: ) client_kwargs = { - "api_key": settings.anthropic_api_key, + "api_key": anthropic_key, "http_client": http_client } - if settings.anthropic_base_url: - client_kwargs["base_url"] = settings.anthropic_base_url + # 优先使用用户提供的base_url,否则使用全局配置 + base_url = api_base_url if api_provider == "anthropic" else app_settings.anthropic_base_url + if base_url: + client_kwargs["base_url"] = base_url self.anthropic_client = AsyncAnthropic(**client_kwargs) logger.info("✅ Anthropic客户端初始化成功") @@ -123,10 +153,10 @@ class AIService: Returns: 生成的文本 """ - provider = provider or settings.default_ai_provider - model = model or settings.default_model - temperature = temperature or settings.default_temperature - max_tokens = max_tokens or settings.default_max_tokens + provider = provider or self.api_provider + model = model or self.default_model + temperature = temperature or self.default_temperature + max_tokens = max_tokens or self.default_max_tokens if provider == "openai": return await self._generate_openai( @@ -162,10 +192,10 @@ class AIService: Yields: 生成的文本片段 """ - provider = provider or settings.default_ai_provider - model = model or settings.default_model - temperature = temperature or settings.default_temperature - max_tokens = max_tokens or settings.default_max_tokens + provider = provider or self.api_provider + model = model or self.default_model + temperature = temperature or self.default_temperature + max_tokens = max_tokens or self.default_max_tokens if provider == "openai": async for chunk in self._generate_openai_stream( @@ -359,5 +389,37 @@ class AIService: raise -# 创建全局AI服务实例 -ai_service = AIService() \ No newline at end of file +# 创建全局AI服务实例(使用环境变量配置,用于向后兼容) +ai_service = AIService() + + +def create_user_ai_service( + api_provider: str, + api_key: str, + api_base_url: str, + model_name: str, + temperature: float, + max_tokens: int +) -> AIService: + """ + 根据用户设置创建AI服务实例 + + Args: + api_provider: API提供商 + api_key: API密钥 + api_base_url: API基础URL + model_name: 模型名称 + temperature: 温度参数 + max_tokens: 最大tokens + + Returns: + AIService实例 + """ + return AIService( + api_provider=api_provider, + api_key=api_key, + api_base_url=api_base_url, + default_model=model_name, + default_temperature=temperature, + default_max_tokens=max_tokens + ) \ No newline at end of file diff --git a/backend/app/services/prompt_service.py b/backend/app/services/prompt_service.py index bad9a86..a96b68d 100644 --- a/backend/app/services/prompt_service.py +++ b/backend/app/services/prompt_service.py @@ -26,7 +26,14 @@ class PromptService: - 为故事发展提供支撑 - 具有独特性和吸引力 -**重要:你必须只返回纯JSON格式,不要包含任何markdown标记、代码块标记或其他说明文字。** +**重要格式要求:** +1. 只返回纯JSON格式,不要包含任何markdown标记、代码块标记或其他说明文字 +2. 不要在JSON字符串值中使用中文引号(""''),请使用英文引号或直接省略引号 +3. 专有名词和强调内容可以使用【】或《》标记,不要用引号 + +**正确示例**: +- ✅ "距离【大灾变】爆发" 或 "距离大灾变爆发" +- ❌ "距离"大灾变"爆发" (会导致JSON解析失败) 请严格按照以下JSON格式返回(每个字段为200-300字的文本描述): {{ @@ -36,7 +43,10 @@ class PromptService: "rules": "世界规则的详细描述,包括运行法则、特殊设定、社会规则、权力结构" }} -再次强调:只返回纯JSON对象,不要有```json```这样的标记,不要有任何额外的文字说明。""" +再次强调: +1. 只返回纯JSON对象,不要有```json```这样的标记 +2. 文本中不要使用中文引号(""),使用【】或《》代替 +3. 不要有任何额外的文字说明""" # 批量角色生成提示词 CHARACTERS_BATCH_GENERATION = """你是一位专业的角色设定师。请根据以下世界观和要求,生成{count}个立体丰满的角色和组织: @@ -67,7 +77,10 @@ class PromptService: - 组织要有存在的合理性 - 所有实体要为故事服务 -**重要:你必须只返回纯JSON数组格式,不要包含任何markdown标记、代码块标记或其他说明文字。** +**重要格式要求:** +1. 只返回纯JSON数组格式,不要包含任何markdown标记、代码块标记或其他说明文字 +2. 不要在JSON字符串值中使用中文引号(""''),请使用英文引号或【】《》标记 +3. 专有名词和强调内容使用【】或《》,不要用引号 请严格按照以下JSON数组格式返回(每个角色为数组中的一个对象): [ @@ -134,7 +147,8 @@ class PromptService: 再次强调: 1. 只返回纯JSON数组,不要有```json```这样的标记 2. 数组中必须精确包含{count}个对象 -3. 不要引用任何本批次中不存在的角色或组织名称""" +3. 不要引用任何本批次中不存在的角色或组织名称 +4. 文本描述中不要使用中文引号(""),改用【】或《》""" # 完整大纲生成提示词 COMPLETE_OUTLINE_GENERATION = """你是一位经验丰富的小说作家和编剧。请根据以下信息生成完整的{chapter_count}章小说大纲: @@ -166,7 +180,10 @@ class PromptService: - 节奏把控:有张有弛 - 视角统一:采用{narrative_perspective}视角叙事 -**重要:你必须只返回纯JSON数组格式,不要包含任何markdown标记、代码块标记或其他说明文字。** +**重要格式要求:** +1. 只返回纯JSON数组格式,不要包含任何markdown标记、代码块标记或其他说明文字 +2. 不要在JSON字符串值中使用中文引号(""''),请使用【】或《》标记 +3. 专有名词、书名、事件名使用【】或《》 请严格按照以下JSON数组格式返回(共{chapter_count}个章节对象): [ @@ -192,7 +209,10 @@ class PromptService: }} ] -再次强调:只返回纯JSON数组,不要有```json```这样的标记,不要有任何额外的文字说明。数组中要包含{chapter_count}个章节对象。""" +再次强调: +1. 只返回纯JSON数组,不要有```json```这样的标记 +2. 数组中要包含{chapter_count}个章节对象 +3. 文本中不要使用中文引号(""),改用【】或《》""" # 大纲续写提示词 OUTLINE_CONTINUE_GENERATION = """你是一位经验丰富的小说作家和编剧。请基于以下信息续写小说大纲: @@ -232,7 +252,10 @@ class PromptService: - 保持与已有章节相同的风格和详细程度 - 推进角色成长和情节发展 -**重要:你必须只返回纯JSON数组格式,不要包含任何markdown标记、代码块标记或其他说明文字。** +**重要格式要求:** +1. 只返回纯JSON数组格式,不要包含任何markdown标记、代码块标记或其他说明文字 +2. 不要在JSON字符串值中使用中文引号(""''),请使用【】或《》 +3. 文本描述中的专有名词使用【】标记 请严格按照以下JSON数组格式返回(共{chapter_count}个章节对象): [ @@ -262,7 +285,8 @@ class PromptService: 1. 只返回纯JSON数组,不要有```json```这样的标记 2. 数组中要包含{chapter_count}个章节对象 3. 每个summary必须是100-200字的详细描述 -4. 确保字段结构与已有章节完全一致""" +4. 确保字段结构与已有章节完全一致 +5. 文本中不要使用中文引号(""),改用【】或《》""" # AI去味提示词(核心特色功能) AI_DENOISING = """你是一位追求自然写作风格的编辑。你的任务是将AI生成的文本改写得更像人类作家的手笔。 @@ -431,7 +455,10 @@ class PromptService: 4. 情节的递进和冲突升级 5. 角色的成长弧线 -**重要:你必须只返回纯JSON格式,不要包含任何markdown标记、代码块标记或其他说明文字。** +**重要格式要求:** +1. 只返回纯JSON格式,不要包含任何markdown标记、代码块标记或其他说明文字 +2. 不要在JSON字符串值中使用中文引号(""''),改用【】或《》 +3. 专有名词和强调内容使用【】标记 请严格按照以下JSON格式返回: {{ @@ -444,7 +471,10 @@ class PromptService: ] }} -再次强调:只返回纯JSON对象,不要有```json```这样的标记,不要有任何额外的文字说明。""" +再次强调: +1. 只返回纯JSON对象,不要有```json```这样的标记 +2. 文本中不要使用中文引号(""),改用【】或《》 +3. 不要有任何额外的文字说明""" # 单个角色生成提示词 SINGLE_CHARACTER_GENERATION = """你是一位专业的角色设定师。请根据以下信息创建一个立体饱满的小说角色。 @@ -487,7 +517,10 @@ class PromptService: - 特殊技能或知识 - 符合世界观设定 -**你必须只返回纯JSON格式,不要包含任何markdown标记、代码块标记或其他说明文字。** +**重要格式要求:** +1. 只返回纯JSON格式,不要包含任何markdown标记、代码块标记或其他说明文字 +2. 不要在JSON字符串值中使用中文引号(""''),改用【】或《》 +3. 文本描述中的专有名词使用【】标记 请严格按照以下JSON格式返回: {{ @@ -543,7 +576,10 @@ class PromptService: - 配角要有独特性,不能是工具人 - 所有设定要为故事服务 -再次强调:只返回纯JSON对象,不要有```json```这样的标记,不要有任何额外的文字说明。""" +再次强调: +1. 只返回纯JSON对象,不要有```json```这样的标记 +2. 文本中不要使用中文引号(""),改用【】或《》 +3. 不要有任何额外的文字说明""" @staticmethod def format_prompt(template: str, **kwargs) -> str: diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 18d76b6..cf26153 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -10,6 +10,7 @@ import Characters from './pages/Characters'; import Relationships from './pages/Relationships'; import Organizations from './pages/Organizations'; import Chapters from './pages/Chapters'; +import Settings from './pages/Settings'; // import Polish from './pages/Polish'; import Login from './pages/Login'; import AuthCallback from './pages/AuthCallback'; @@ -31,6 +32,7 @@ function App() { } /> } /> + } /> }> } /> } /> diff --git a/frontend/src/pages/ProjectList.tsx b/frontend/src/pages/ProjectList.tsx index fd06b2e..b79262b 100644 --- a/frontend/src/pages/ProjectList.tsx +++ b/frontend/src/pages/ProjectList.tsx @@ -1,7 +1,7 @@ import { useEffect } from 'react'; import { useNavigate } from 'react-router-dom'; import { Card, Button, Empty, Modal, message, Spin, Row, Col, Statistic, Space, Tag, Progress, Typography, Tooltip, Badge } from 'antd'; -import { EditOutlined, DeleteOutlined, BookOutlined, RocketOutlined, CalendarOutlined, FileTextOutlined, TrophyOutlined, FireOutlined } from '@ant-design/icons'; +import { EditOutlined, DeleteOutlined, BookOutlined, RocketOutlined, CalendarOutlined, FileTextOutlined, TrophyOutlined, FireOutlined, SettingOutlined } from '@ant-design/icons'; import { useStore } from '../store'; import { useProjectSync } from '../store/hooks'; import type { ReactNode } from 'react'; @@ -161,11 +161,36 @@ export default function ProjectList() { style={{ borderRadius: 8, background: 'linear-gradient(135deg, #667eea 0%, #764ba2 100%)', - border: 'none' + border: 'none', + boxShadow: '0 2px 8px rgba(102, 126, 234, 0.4)' }} > 向导创建 + diff --git a/frontend/src/pages/Settings.tsx b/frontend/src/pages/Settings.tsx new file mode 100644 index 0000000..db85280 --- /dev/null +++ b/frontend/src/pages/Settings.tsx @@ -0,0 +1,507 @@ +import { useState, useEffect } from 'react'; +import { useNavigate } from 'react-router-dom'; +import { Card, Form, Input, Button, Select, Slider, InputNumber, message, Space, Typography, Spin, Modal, Tooltip, Alert } from 'antd'; +import { SettingOutlined, SaveOutlined, DeleteOutlined, ReloadOutlined, ArrowLeftOutlined, InfoCircleOutlined} from '@ant-design/icons'; +import { settingsApi } from '../services/api'; +import type { SettingsUpdate } from '../types'; + +const { Title, Paragraph } = Typography; +const { Option } = Select; + +export default function SettingsPage() { + const navigate = useNavigate(); + const [form] = Form.useForm(); + const [loading, setLoading] = useState(false); + const [initialLoading, setInitialLoading] = useState(true); + const [hasSettings, setHasSettings] = useState(false); + const [isDefaultSettings, setIsDefaultSettings] = useState(false); + const [modelOptions, setModelOptions] = useState>([]); + const [fetchingModels, setFetchingModels] = useState(false); + const [modelsFetched, setModelsFetched] = useState(false); + + useEffect(() => { + loadSettings(); + }, []); + + const loadSettings = async () => { + setInitialLoading(true); + try { + const settings = await settingsApi.getSettings(); + form.setFieldsValue(settings); + + // 判断是否为默认设置(id='0'表示来自.env的默认配置) + if (settings.id === '0' || !settings.id) { + setIsDefaultSettings(true); + setHasSettings(false); + } else { + setIsDefaultSettings(false); + setHasSettings(true); + } + } catch (error: any) { + // 如果404表示还没有设置,使用默认值 + if (error?.response?.status === 404) { + setHasSettings(false); + setIsDefaultSettings(true); + form.setFieldsValue({ + api_provider: 'openai', + api_base_url: 'https://api.openai.com/v1', + model_name: 'gpt-4', + temperature: 0.7, + max_tokens: 2000, + }); + } else { + message.error('加载设置失败'); + } + } finally { + setInitialLoading(false); + } + }; + + const handleSave = async (values: SettingsUpdate) => { + setLoading(true); + try { + await settingsApi.saveSettings(values); + message.success('设置已保存'); + setHasSettings(true); + setIsDefaultSettings(false); + } catch (error) { + message.error('保存设置失败'); + } finally { + setLoading(false); + } + }; + + const handleReset = () => { + Modal.confirm({ + title: '重置设置', + content: '确定要重置为默认值吗?', + okText: '确定', + cancelText: '取消', + onOk: () => { + form.setFieldsValue({ + api_provider: 'openai', + api_key: '', + api_base_url: 'https://api.openai.com/v1', + model_name: 'gpt-4', + temperature: 0.7, + max_tokens: 2000, + }); + message.info('已重置为默认值,请点击保存'); + }, + }); + }; + + const handleDelete = () => { + Modal.confirm({ + title: '删除设置', + content: '确定要删除所有设置吗?此操作不可恢复。', + okText: '确定', + cancelText: '取消', + okType: 'danger', + onOk: async () => { + setLoading(true); + try { + await settingsApi.deleteSettings(); + message.success('设置已删除'); + setHasSettings(false); + form.resetFields(); + } catch (error) { + message.error('删除设置失败'); + } finally { + setLoading(false); + } + }, + }); + }; + + const apiProviders = [ + { value: 'openai', label: 'OpenAI', defaultUrl: 'https://api.openai.com/v1' }, + { value: 'azure', label: 'Azure OpenAI', defaultUrl: 'https://YOUR-RESOURCE.openai.azure.com' }, + { value: 'anthropic', label: 'Anthropic', defaultUrl: 'https://api.anthropic.com' }, + { value: 'custom', label: '自定义', defaultUrl: '' }, + ]; + + const handleProviderChange = (value: string) => { + const provider = apiProviders.find(p => p.value === value); + if (provider && provider.defaultUrl) { + form.setFieldValue('api_base_url', provider.defaultUrl); + } + // 清空模型列表,需要重新获取 + setModelOptions([]); + setModelsFetched(false); + }; + + const handleFetchModels = async (silent: boolean = false) => { + const apiKey = form.getFieldValue('api_key'); + const apiBaseUrl = form.getFieldValue('api_base_url'); + const provider = form.getFieldValue('api_provider'); + + if (!apiKey || !apiBaseUrl) { + if (!silent) { + message.warning('请先填写 API 密钥和 API 地址'); + } + return; + } + + setFetchingModels(true); + try { + const response = await settingsApi.getAvailableModels({ + api_key: apiKey, + api_base_url: apiBaseUrl, + provider: provider || 'openai' + }); + + setModelOptions(response.models); + setModelsFetched(true); + if (!silent) { + message.success(`成功获取 ${response.count || response.models.length} 个可用模型`); + } + } catch (error: any) { + const errorMsg = error?.response?.data?.detail || '获取模型列表失败'; + if (!silent) { + message.error(errorMsg); + } + setModelOptions([]); + setModelsFetched(true); // 即使失败也标记为已尝试,避免重复请求 + } finally { + setFetchingModels(false); + } + }; + + const handleModelSelectFocus = () => { + // 如果还没有获取过模型列表,自动获取 + if (!modelsFetched && !fetchingModels) { + handleFetchModels(true); // silent模式,不显示成功消息 + } + }; + + return ( +
+
+ + + {/* 标题栏 */} +
+ +
+ + + 配置你的AI API接口参数,这些设置将用于小说生成、角色创建等AI功能。 + + + {/* 默认配置提示 */} + {isDefaultSettings && ( + +

+ 当前显示的是从服务器 .env 文件读取的默认配置。 +

+

+ 点击"保存设置"后,配置将保存到数据库并同步更新到 .env 文件。 +

+
+ } + type="info" + showIcon + style={{ marginBottom: 16 }} + /> + )} + + {/* 已保存配置提示 */} + {hasSettings && !isDefaultSettings && ( + + )} + + {/* 表单 */} + +
+ + API 提供商 + + + + + } + name="api_provider" + rules={[{ required: true, message: '请选择API提供商' }]} + > + + + + + API 密钥 + + + + + } + name="api_key" + rules={[{ required: true, message: '请输入API密钥' }]} + > + + + + + API 地址 + + + + + } + name="api_base_url" + rules={[ + { required: true, message: '请输入API地址' }, + { type: 'url', message: '请输入有效的URL' } + ]} + > + + + + + 模型名称 + + + + + } + name="model_name" + rules={[{ required: true, message: '请输入或选择模型名称' }]} + > +