支持自定义API接口
This commit is contained in:
@@ -18,9 +18,10 @@ from app.schemas.chapter import (
|
|||||||
ChapterResponse,
|
ChapterResponse,
|
||||||
ChapterListResponse
|
ChapterListResponse
|
||||||
)
|
)
|
||||||
from app.services.ai_service import ai_service
|
from app.services.ai_service import AIService
|
||||||
from app.services.prompt_service import prompt_service
|
from app.services.prompt_service import prompt_service
|
||||||
from app.logger import get_logger
|
from app.logger import get_logger
|
||||||
|
from app.api.settings import get_user_ai_service
|
||||||
|
|
||||||
router = APIRouter(prefix="/chapters", tags=["章节管理"])
|
router = APIRouter(prefix="/chapters", tags=["章节管理"])
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
@@ -247,7 +248,8 @@ async def check_can_generate(
|
|||||||
@router.post("/{chapter_id}/generate", summary="AI创作章节内容")
|
@router.post("/{chapter_id}/generate", summary="AI创作章节内容")
|
||||||
async def generate_chapter_content(
|
async def generate_chapter_content(
|
||||||
chapter_id: str,
|
chapter_id: str,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db),
|
||||||
|
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
根据大纲、前置章节内容和项目信息AI创作章节完整内容
|
根据大纲、前置章节内容和项目信息AI创作章节完整内容
|
||||||
@@ -372,7 +374,7 @@ async def generate_chapter_content(
|
|||||||
logger.info(f"开始AI创作章节 {chapter_id}")
|
logger.info(f"开始AI创作章节 {chapter_id}")
|
||||||
|
|
||||||
# 调用AI生成
|
# 调用AI生成
|
||||||
ai_content = await ai_service.generate_text(
|
ai_content = await user_ai_service.generate_text(
|
||||||
prompt=prompt
|
prompt=prompt
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -410,7 +412,8 @@ async def generate_chapter_content(
|
|||||||
@router.post("/{chapter_id}/generate-stream", summary="AI创作章节内容(流式)")
|
@router.post("/{chapter_id}/generate-stream", summary="AI创作章节内容(流式)")
|
||||||
async def generate_chapter_content_stream(
|
async def generate_chapter_content_stream(
|
||||||
chapter_id: str,
|
chapter_id: str,
|
||||||
request: Request
|
request: Request,
|
||||||
|
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
根据大纲、前置章节内容和项目信息AI创作章节完整内容(流式返回)
|
根据大纲、前置章节内容和项目信息AI创作章节完整内容(流式返回)
|
||||||
@@ -569,7 +572,7 @@ async def generate_chapter_content_stream(
|
|||||||
|
|
||||||
# 流式生成内容
|
# 流式生成内容
|
||||||
full_content = ""
|
full_content = ""
|
||||||
async for chunk in ai_service.generate_text_stream(prompt=prompt):
|
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||||
full_content += chunk
|
full_content += chunk
|
||||||
yield f"data: {json.dumps({'type': 'content', 'content': chunk}, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps({'type': 'content', 'content': chunk}, ensure_ascii=False)}\n\n"
|
||||||
await asyncio.sleep(0) # 让出控制权
|
await asyncio.sleep(0) # 让出控制权
|
||||||
|
|||||||
@@ -15,9 +15,10 @@ from app.schemas.character import (
|
|||||||
CharacterListResponse,
|
CharacterListResponse,
|
||||||
CharacterGenerateRequest
|
CharacterGenerateRequest
|
||||||
)
|
)
|
||||||
from app.services.ai_service import ai_service
|
from app.services.ai_service import AIService
|
||||||
from app.services.prompt_service import prompt_service
|
from app.services.prompt_service import prompt_service
|
||||||
from app.logger import get_logger
|
from app.logger import get_logger
|
||||||
|
from app.api.settings import get_user_ai_service
|
||||||
|
|
||||||
router = APIRouter(prefix="/characters", tags=["角色管理"])
|
router = APIRouter(prefix="/characters", tags=["角色管理"])
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
@@ -134,7 +135,8 @@ async def delete_character(
|
|||||||
@router.post("/generate", response_model=CharacterResponse, summary="AI生成角色")
|
@router.post("/generate", response_model=CharacterResponse, summary="AI生成角色")
|
||||||
async def generate_character(
|
async def generate_character(
|
||||||
request: CharacterGenerateRequest,
|
request: CharacterGenerateRequest,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db),
|
||||||
|
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
使用AI生成角色卡
|
使用AI生成角色卡
|
||||||
@@ -216,7 +218,7 @@ async def generate_character(
|
|||||||
logger.info(f" - Prompt长度:{len(prompt)} 字符")
|
logger.info(f" - Prompt长度:{len(prompt)} 字符")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ai_response = await ai_service.generate_text(
|
ai_response = await user_ai_service.generate_text(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
provider=request.provider,
|
provider=request.provider,
|
||||||
model=request.model
|
model=request.model
|
||||||
|
|||||||
@@ -19,9 +19,10 @@ from app.schemas.outline import (
|
|||||||
OutlineGenerateRequest,
|
OutlineGenerateRequest,
|
||||||
OutlineReorderRequest
|
OutlineReorderRequest
|
||||||
)
|
)
|
||||||
from app.services.ai_service import ai_service
|
from app.services.ai_service import AIService
|
||||||
from app.services.prompt_service import prompt_service
|
from app.services.prompt_service import prompt_service
|
||||||
from app.logger import get_logger
|
from app.logger import get_logger
|
||||||
|
from app.api.settings import get_user_ai_service
|
||||||
|
|
||||||
router = APIRouter(prefix="/outlines", tags=["大纲管理"])
|
router = APIRouter(prefix="/outlines", tags=["大纲管理"])
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
@@ -326,7 +327,8 @@ async def reorder_outlines(
|
|||||||
@router.post("/generate", response_model=OutlineListResponse, summary="AI生成/续写大纲")
|
@router.post("/generate", response_model=OutlineListResponse, summary="AI生成/续写大纲")
|
||||||
async def generate_outline(
|
async def generate_outline(
|
||||||
request: OutlineGenerateRequest,
|
request: OutlineGenerateRequest,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db),
|
||||||
|
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
使用AI生成或续写小说大纲 - 智能模式
|
使用AI生成或续写小说大纲 - 智能模式
|
||||||
@@ -363,7 +365,7 @@ async def generate_outline(
|
|||||||
# 模式:全新生成
|
# 模式:全新生成
|
||||||
if actual_mode == "new":
|
if actual_mode == "new":
|
||||||
return await _generate_new_outline(
|
return await _generate_new_outline(
|
||||||
request, project, db
|
request, project, db, user_ai_service
|
||||||
)
|
)
|
||||||
|
|
||||||
# 模式:续写
|
# 模式:续写
|
||||||
@@ -375,7 +377,7 @@ async def generate_outline(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return await _continue_outline(
|
return await _continue_outline(
|
||||||
request, project, existing_outlines, db
|
request, project, existing_outlines, db, user_ai_service
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@@ -394,7 +396,8 @@ async def generate_outline(
|
|||||||
async def _generate_new_outline(
|
async def _generate_new_outline(
|
||||||
request: OutlineGenerateRequest,
|
request: OutlineGenerateRequest,
|
||||||
project: Project,
|
project: Project,
|
||||||
db: AsyncSession
|
db: AsyncSession,
|
||||||
|
user_ai_service: AIService
|
||||||
) -> OutlineListResponse:
|
) -> OutlineListResponse:
|
||||||
"""全新生成大纲"""
|
"""全新生成大纲"""
|
||||||
logger.info(f"全新生成大纲 - 项目: {project.id}, keep_existing: {request.keep_existing}")
|
logger.info(f"全新生成大纲 - 项目: {project.id}, keep_existing: {request.keep_existing}")
|
||||||
@@ -427,7 +430,7 @@ async def _generate_new_outline(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 调用AI
|
# 调用AI
|
||||||
ai_response = await ai_service.generate_text(
|
ai_response = await user_ai_service.generate_text(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
provider=request.provider,
|
provider=request.provider,
|
||||||
model=request.model
|
model=request.model
|
||||||
@@ -473,7 +476,8 @@ async def _continue_outline(
|
|||||||
request: OutlineGenerateRequest,
|
request: OutlineGenerateRequest,
|
||||||
project: Project,
|
project: Project,
|
||||||
existing_outlines: List[Outline],
|
existing_outlines: List[Outline],
|
||||||
db: AsyncSession
|
db: AsyncSession,
|
||||||
|
user_ai_service: AIService
|
||||||
) -> OutlineListResponse:
|
) -> OutlineListResponse:
|
||||||
"""续写大纲"""
|
"""续写大纲"""
|
||||||
logger.info(f"续写大纲 - 项目: {project.id}, 已有: {len(existing_outlines)} 章")
|
logger.info(f"续写大纲 - 项目: {project.id}, 已有: {len(existing_outlines)} 章")
|
||||||
@@ -536,7 +540,7 @@ async def _continue_outline(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 调用AI
|
# 调用AI
|
||||||
ai_response = await ai_service.generate_text(
|
ai_response = await user_ai_service.generate_text(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
provider=request.provider,
|
provider=request.provider,
|
||||||
model=request.model
|
model=request.model
|
||||||
|
|||||||
@@ -5,9 +5,10 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.models.generation_history import GenerationHistory
|
from app.models.generation_history import GenerationHistory
|
||||||
from app.schemas.polish import PolishRequest, PolishResponse
|
from app.schemas.polish import PolishRequest, PolishResponse
|
||||||
from app.services.ai_service import ai_service
|
from app.services.ai_service import AIService
|
||||||
from app.services.prompt_service import prompt_service
|
from app.services.prompt_service import prompt_service
|
||||||
from app.logger import get_logger
|
from app.logger import get_logger
|
||||||
|
from app.api.settings import get_user_ai_service
|
||||||
|
|
||||||
router = APIRouter(prefix="/polish", tags=["AI去味"])
|
router = APIRouter(prefix="/polish", tags=["AI去味"])
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
@@ -16,7 +17,8 @@ logger = get_logger(__name__)
|
|||||||
@router.post("", response_model=PolishResponse, summary="AI去味")
|
@router.post("", response_model=PolishResponse, summary="AI去味")
|
||||||
async def polish_text(
|
async def polish_text(
|
||||||
request: PolishRequest,
|
request: PolishRequest,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db),
|
||||||
|
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
AI去味 - 将AI生成的文本改写得更像人类作家的手笔
|
AI去味 - 将AI生成的文本改写得更像人类作家的手笔
|
||||||
@@ -83,7 +85,8 @@ async def polish_batch(
|
|||||||
project_id: int = None,
|
project_id: int = None,
|
||||||
provider: str = None,
|
provider: str = None,
|
||||||
model: str = None,
|
model: str = None,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db),
|
||||||
|
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
批量处理多个文本的AI去味
|
批量处理多个文本的AI去味
|
||||||
@@ -98,7 +101,7 @@ async def polish_batch(
|
|||||||
|
|
||||||
prompt = prompt_service.get_denoising_prompt(original_text=text)
|
prompt = prompt_service.get_denoising_prompt(original_text=text)
|
||||||
|
|
||||||
polished_text = await ai_service.generate_text(
|
polished_text = await user_ai_service.generate_text(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
model=model
|
model=model
|
||||||
|
|||||||
@@ -0,0 +1,299 @@
|
|||||||
|
"""
|
||||||
|
设置管理 API
|
||||||
|
"""
|
||||||
|
from fastapi import APIRouter, HTTPException, Request, Depends
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy import select
|
||||||
|
from typing import Dict, Any, List
|
||||||
|
from pathlib import Path
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from app.database import get_db
|
||||||
|
from app.models.settings import Settings
|
||||||
|
from app.schemas.settings import SettingsCreate, SettingsUpdate, SettingsResponse
|
||||||
|
from app.user_manager import User
|
||||||
|
from app.logger import get_logger
|
||||||
|
from app.config import settings as app_settings, PROJECT_ROOT
|
||||||
|
from app.services.ai_service import AIService, create_user_ai_service
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/settings", tags=["设置管理"])
|
||||||
|
|
||||||
|
|
||||||
|
def read_env_defaults() -> Dict[str, Any]:
|
||||||
|
"""从.env文件读取默认配置(仅读取,不修改)"""
|
||||||
|
return {
|
||||||
|
"api_provider": app_settings.default_ai_provider,
|
||||||
|
"api_key": app_settings.openai_api_key or app_settings.anthropic_api_key or "",
|
||||||
|
"api_base_url": app_settings.openai_base_url or app_settings.anthropic_base_url or "",
|
||||||
|
"model_name": app_settings.default_model,
|
||||||
|
"temperature": app_settings.default_temperature,
|
||||||
|
"max_tokens": app_settings.default_max_tokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def require_login(request: Request):
|
||||||
|
"""依赖:要求用户已登录"""
|
||||||
|
if not hasattr(request.state, "user") or not request.state.user:
|
||||||
|
raise HTTPException(status_code=401, detail="需要登录")
|
||||||
|
return request.state.user
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_ai_service(
|
||||||
|
user: User = Depends(require_login),
|
||||||
|
db: AsyncSession = Depends(get_db)
|
||||||
|
) -> AIService:
|
||||||
|
"""
|
||||||
|
依赖:获取当前用户的AI服务实例
|
||||||
|
从数据库读取用户设置并创建对应的AI服务
|
||||||
|
"""
|
||||||
|
result = await db.execute(
|
||||||
|
select(Settings).where(Settings.user_id == user.user_id)
|
||||||
|
)
|
||||||
|
settings = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not settings:
|
||||||
|
# 如果用户没有设置,从.env读取并保存
|
||||||
|
env_defaults = read_env_defaults()
|
||||||
|
settings = Settings(
|
||||||
|
user_id=user.user_id,
|
||||||
|
**env_defaults
|
||||||
|
)
|
||||||
|
db.add(settings)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(settings)
|
||||||
|
logger.info(f"用户 {user.user_id} 首次使用AI服务,已从.env同步设置到数据库")
|
||||||
|
|
||||||
|
# 使用用户设置创建AI服务实例
|
||||||
|
return create_user_ai_service(
|
||||||
|
api_provider=settings.api_provider,
|
||||||
|
api_key=settings.api_key,
|
||||||
|
api_base_url=settings.api_base_url or "",
|
||||||
|
model_name=settings.model_name,
|
||||||
|
temperature=settings.temperature,
|
||||||
|
max_tokens=settings.max_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=SettingsResponse)
|
||||||
|
async def get_settings(
|
||||||
|
user: User = Depends(require_login),
|
||||||
|
db: AsyncSession = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取当前用户的设置
|
||||||
|
如果用户没有保存过设置,自动从.env创建并保存到数据库
|
||||||
|
"""
|
||||||
|
result = await db.execute(
|
||||||
|
select(Settings).where(Settings.user_id == user.user_id)
|
||||||
|
)
|
||||||
|
settings = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not settings:
|
||||||
|
# 如果用户没有保存过设置,从.env读取默认配置并保存到数据库
|
||||||
|
env_defaults = read_env_defaults()
|
||||||
|
logger.info(f"用户 {user.user_id} 首次获取设置,自动从.env同步到数据库")
|
||||||
|
|
||||||
|
# 创建新设置并保存到数据库
|
||||||
|
settings = Settings(
|
||||||
|
user_id=user.user_id,
|
||||||
|
**env_defaults
|
||||||
|
)
|
||||||
|
db.add(settings)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(settings)
|
||||||
|
logger.info(f"用户 {user.user_id} 的设置已从.env同步到数据库")
|
||||||
|
|
||||||
|
logger.info(f"用户 {user.user_id} 获取已保存的设置")
|
||||||
|
return settings
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", response_model=SettingsResponse)
|
||||||
|
async def save_settings(
|
||||||
|
data: SettingsCreate,
|
||||||
|
user: User = Depends(require_login),
|
||||||
|
db: AsyncSession = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
创建或更新当前用户的设置(Upsert)
|
||||||
|
如果设置已存在则更新,否则创建新设置
|
||||||
|
仅保存到数据库
|
||||||
|
"""
|
||||||
|
# 查找现有设置
|
||||||
|
result = await db.execute(
|
||||||
|
select(Settings).where(Settings.user_id == user.user_id)
|
||||||
|
)
|
||||||
|
settings = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
# 准备数据
|
||||||
|
settings_dict = data.model_dump(exclude_unset=True)
|
||||||
|
|
||||||
|
if settings:
|
||||||
|
# 更新现有设置
|
||||||
|
for key, value in settings_dict.items():
|
||||||
|
setattr(settings, key, value)
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(settings)
|
||||||
|
logger.info(f"用户 {user.user_id} 更新设置")
|
||||||
|
else:
|
||||||
|
# 创建新设置
|
||||||
|
settings = Settings(
|
||||||
|
user_id=user.user_id,
|
||||||
|
**settings_dict
|
||||||
|
)
|
||||||
|
db.add(settings)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(settings)
|
||||||
|
logger.info(f"用户 {user.user_id} 创建设置")
|
||||||
|
|
||||||
|
return settings
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("", response_model=SettingsResponse)
|
||||||
|
async def update_settings(
|
||||||
|
data: SettingsUpdate,
|
||||||
|
user: User = Depends(require_login),
|
||||||
|
db: AsyncSession = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
更新当前用户的设置
|
||||||
|
仅保存到数据库
|
||||||
|
"""
|
||||||
|
result = await db.execute(
|
||||||
|
select(Settings).where(Settings.user_id == user.user_id)
|
||||||
|
)
|
||||||
|
settings = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not settings:
|
||||||
|
raise HTTPException(status_code=404, detail="设置不存在,请先创建设置")
|
||||||
|
|
||||||
|
# 更新设置
|
||||||
|
update_data = data.model_dump(exclude_unset=True)
|
||||||
|
for key, value in update_data.items():
|
||||||
|
setattr(settings, key, value)
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(settings)
|
||||||
|
logger.info(f"用户 {user.user_id} 更新设置")
|
||||||
|
|
||||||
|
return settings
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("")
|
||||||
|
async def delete_settings(
|
||||||
|
user: User = Depends(require_login),
|
||||||
|
db: AsyncSession = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
删除当前用户的设置
|
||||||
|
"""
|
||||||
|
result = await db.execute(
|
||||||
|
select(Settings).where(Settings.user_id == user.user_id)
|
||||||
|
)
|
||||||
|
settings = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not settings:
|
||||||
|
raise HTTPException(status_code=404, detail="设置不存在")
|
||||||
|
|
||||||
|
await db.delete(settings)
|
||||||
|
await db.commit()
|
||||||
|
logger.info(f"用户 {user.user_id} 删除设置")
|
||||||
|
|
||||||
|
return {"message": "设置已删除", "user_id": user.user_id}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/models")
|
||||||
|
async def get_available_models(
|
||||||
|
api_key: str,
|
||||||
|
api_base_url: str,
|
||||||
|
provider: str = "openai"
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
从配置的 API 获取可用的模型列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: API 密钥
|
||||||
|
api_base_url: API 基础 URL
|
||||||
|
provider: API 提供商 (openai, anthropic, azure, custom)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
模型列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
|
if provider == "openai" or provider == "azure" or provider == "custom":
|
||||||
|
# OpenAI 兼容接口获取模型列表
|
||||||
|
url = f"{api_base_url.rstrip('/')}/models"
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"正在从 {url} 获取模型列表")
|
||||||
|
response = await client.get(url, headers=headers)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
models = []
|
||||||
|
|
||||||
|
if "data" in data and isinstance(data["data"], list):
|
||||||
|
for model in data["data"]:
|
||||||
|
model_id = model.get("id", "")
|
||||||
|
# 过滤出常用的文本生成模型
|
||||||
|
if any(keyword in model_id.lower() for keyword in [
|
||||||
|
"gpt", "gemini", "claude", "llama", "mistral", "qwen", "deepseek"
|
||||||
|
]):
|
||||||
|
models.append({
|
||||||
|
"value": model_id,
|
||||||
|
"label": model_id,
|
||||||
|
"description": model.get("description", "") or f"Created: {model.get('created', 'N/A')}"
|
||||||
|
})
|
||||||
|
|
||||||
|
if not models:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail="未能从 API 获取到可用的模型列表"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"成功获取 {len(models)} 个模型")
|
||||||
|
return {
|
||||||
|
"provider": provider,
|
||||||
|
"models": models,
|
||||||
|
"count": len(models)
|
||||||
|
}
|
||||||
|
|
||||||
|
elif provider == "anthropic":
|
||||||
|
# Anthropic 没有公开的模型列表API
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Anthropic 不支持自动获取模型列表,请手动输入模型名称"
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"不支持的提供商: {provider}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
logger.error(f"获取模型列表失败 (HTTP {e.response.status_code}): {e.response.text}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"无法从 API 获取模型列表 (HTTP {e.response.status_code})"
|
||||||
|
)
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
logger.error(f"请求模型列表失败: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"无法连接到 API: {str(e)}"
|
||||||
|
)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取模型列表时发生错误: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"获取模型列表失败: {str(e)}"
|
||||||
|
)
|
||||||
@@ -4,6 +4,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from typing import Dict, Any, AsyncGenerator
|
from typing import Dict, Any, AsyncGenerator
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.models.project import Project
|
from app.models.project import Project
|
||||||
@@ -11,10 +12,11 @@ from app.models.character import Character
|
|||||||
from app.models.outline import Outline
|
from app.models.outline import Outline
|
||||||
from app.models.chapter import Chapter
|
from app.models.chapter import Chapter
|
||||||
from app.models.relationship import CharacterRelationship, Organization, OrganizationMember, RelationshipType
|
from app.models.relationship import CharacterRelationship, Organization, OrganizationMember, RelationshipType
|
||||||
from app.services.ai_service import ai_service
|
from app.services.ai_service import AIService
|
||||||
from app.services.prompt_service import prompt_service
|
from app.services.prompt_service import prompt_service
|
||||||
from app.logger import get_logger
|
from app.logger import get_logger
|
||||||
from app.utils.sse_response import SSEResponse, create_sse_response
|
from app.utils.sse_response import SSEResponse, create_sse_response
|
||||||
|
from app.api.settings import get_user_ai_service
|
||||||
|
|
||||||
router = APIRouter(prefix="/wizard-stream", tags=["项目创建向导(流式)"])
|
router = APIRouter(prefix="/wizard-stream", tags=["项目创建向导(流式)"])
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
@@ -22,7 +24,8 @@ logger = get_logger(__name__)
|
|||||||
|
|
||||||
async def world_building_generator(
|
async def world_building_generator(
|
||||||
data: Dict[str, Any],
|
data: Dict[str, Any],
|
||||||
db: AsyncSession
|
db: AsyncSession,
|
||||||
|
user_ai_service: AIService
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""世界构建流式生成器"""
|
"""世界构建流式生成器"""
|
||||||
# 标记数据库会话是否已提交
|
# 标记数据库会话是否已提交
|
||||||
@@ -61,7 +64,7 @@ async def world_building_generator(
|
|||||||
accumulated_text = ""
|
accumulated_text = ""
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
|
|
||||||
async for chunk in ai_service.generate_text_stream(
|
async for chunk in user_ai_service.generate_text_stream(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
model=model
|
model=model
|
||||||
@@ -87,24 +90,26 @@ async def world_building_generator(
|
|||||||
world_data = {}
|
world_data = {}
|
||||||
try:
|
try:
|
||||||
cleaned_text = accumulated_text.strip()
|
cleaned_text = accumulated_text.strip()
|
||||||
|
|
||||||
|
# 移除markdown代码块标记
|
||||||
if cleaned_text.startswith('```json'):
|
if cleaned_text.startswith('```json'):
|
||||||
cleaned_text = cleaned_text[7:]
|
cleaned_text = cleaned_text[7:].lstrip('\n\r')
|
||||||
if cleaned_text.startswith('```'):
|
elif cleaned_text.startswith('```'):
|
||||||
cleaned_text = cleaned_text[3:]
|
cleaned_text = cleaned_text[3:].lstrip('\n\r')
|
||||||
if cleaned_text.endswith('```'):
|
if cleaned_text.endswith('```'):
|
||||||
cleaned_text = cleaned_text[:-3]
|
cleaned_text = cleaned_text[:-3].rstrip('\n\r')
|
||||||
cleaned_text = cleaned_text.strip()
|
cleaned_text = cleaned_text.strip()
|
||||||
|
|
||||||
world_data = json.loads(cleaned_text)
|
world_data = json.loads(cleaned_text)
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.error(f"AI返回非JSON格式: {e}")
|
logger.error(f"世界构建JSON解析失败: {e}")
|
||||||
world_data = {
|
world_data = {
|
||||||
"time_period": accumulated_text[:300] if len(accumulated_text) > 300 else accumulated_text,
|
"time_period": "AI返回格式错误,请重试",
|
||||||
"location": "AI返回格式错误,请重试",
|
"location": "AI返回格式错误,请重试",
|
||||||
"atmosphere": "AI返回格式错误,请重试",
|
"atmosphere": "AI返回格式错误,请重试",
|
||||||
"rules": "AI返回格式错误,请重试"
|
"rules": "AI返回格式错误,请重试"
|
||||||
}
|
}
|
||||||
|
|
||||||
# 保存到数据库
|
# 保存到数据库
|
||||||
yield await SSEResponse.send_progress("保存到数据库...", 90)
|
yield await SSEResponse.send_progress("保存到数据库...", 90)
|
||||||
|
|
||||||
@@ -160,18 +165,20 @@ async def world_building_generator(
|
|||||||
@router.post("/world-building", summary="流式生成世界构建")
|
@router.post("/world-building", summary="流式生成世界构建")
|
||||||
async def generate_world_building_stream(
|
async def generate_world_building_stream(
|
||||||
data: Dict[str, Any],
|
data: Dict[str, Any],
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db),
|
||||||
|
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
使用SSE流式生成世界构建,避免超时
|
使用SSE流式生成世界构建,避免超时
|
||||||
前端使用EventSource接收实时进度和结果
|
前端使用EventSource接收实时进度和结果
|
||||||
"""
|
"""
|
||||||
return create_sse_response(world_building_generator(data, db))
|
return create_sse_response(world_building_generator(data, db, user_ai_service))
|
||||||
|
|
||||||
|
|
||||||
async def characters_generator(
|
async def characters_generator(
|
||||||
data: Dict[str, Any],
|
data: Dict[str, Any],
|
||||||
db: AsyncSession
|
db: AsyncSession,
|
||||||
|
user_ai_service: AIService
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""角色批量生成流式生成器 - 优化版:分批+重试"""
|
"""角色批量生成流式生成器 - 优化版:分批+重试"""
|
||||||
db_committed = False
|
db_committed = False
|
||||||
@@ -270,7 +277,7 @@ async def characters_generator(
|
|||||||
|
|
||||||
# 流式生成
|
# 流式生成
|
||||||
accumulated_text = ""
|
accumulated_text = ""
|
||||||
async for chunk in ai_service.generate_text_stream(
|
async for chunk in user_ai_service.generate_text_stream(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
model=model
|
model=model
|
||||||
@@ -280,12 +287,13 @@ async def characters_generator(
|
|||||||
|
|
||||||
# 解析批次结果
|
# 解析批次结果
|
||||||
cleaned_text = accumulated_text.strip()
|
cleaned_text = accumulated_text.strip()
|
||||||
|
# 移除markdown代码块标记
|
||||||
if cleaned_text.startswith('```json'):
|
if cleaned_text.startswith('```json'):
|
||||||
cleaned_text = cleaned_text[7:]
|
cleaned_text = cleaned_text[7:].lstrip('\n\r')
|
||||||
if cleaned_text.startswith('```'):
|
elif cleaned_text.startswith('```'):
|
||||||
cleaned_text = cleaned_text[3:]
|
cleaned_text = cleaned_text[3:].lstrip('\n\r')
|
||||||
if cleaned_text.endswith('```'):
|
if cleaned_text.endswith('```'):
|
||||||
cleaned_text = cleaned_text[:-3]
|
cleaned_text = cleaned_text[:-3].rstrip('\n\r')
|
||||||
cleaned_text = cleaned_text.strip()
|
cleaned_text = cleaned_text.strip()
|
||||||
|
|
||||||
characters_data = json.loads(cleaned_text)
|
characters_data = json.loads(cleaned_text)
|
||||||
@@ -684,17 +692,19 @@ async def characters_generator(
|
|||||||
@router.post("/characters", summary="流式批量生成角色")
|
@router.post("/characters", summary="流式批量生成角色")
|
||||||
async def generate_characters_stream(
|
async def generate_characters_stream(
|
||||||
data: Dict[str, Any],
|
data: Dict[str, Any],
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db),
|
||||||
|
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
使用SSE流式批量生成角色,避免超时
|
使用SSE流式批量生成角色,避免超时
|
||||||
"""
|
"""
|
||||||
return create_sse_response(characters_generator(data, db))
|
return create_sse_response(characters_generator(data, db, user_ai_service))
|
||||||
|
|
||||||
|
|
||||||
async def outline_generator(
|
async def outline_generator(
|
||||||
data: Dict[str, Any],
|
data: Dict[str, Any],
|
||||||
db: AsyncSession
|
db: AsyncSession,
|
||||||
|
user_ai_service: AIService
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""大纲生成流式生成器 - 向导固定生成前8章作为开局"""
|
"""大纲生成流式生成器 - 向导固定生成前8章作为开局"""
|
||||||
db_committed = False
|
db_committed = False
|
||||||
@@ -778,6 +788,7 @@ async def outline_generator(
|
|||||||
batch_requirements += "2. 建立主线冲突和故事钩子\n"
|
batch_requirements += "2. 建立主线冲突和故事钩子\n"
|
||||||
batch_requirements += "3. 展开初期情节,为后续发展埋下伏笔\n"
|
batch_requirements += "3. 展开初期情节,为后续发展埋下伏笔\n"
|
||||||
batch_requirements += "4. 不要试图完结故事,这只是开始部分\n"
|
batch_requirements += "4. 不要试图完结故事,这只是开始部分\n"
|
||||||
|
batch_requirements += "5. 不要在JSON字符串值中使用中文引号(""''),请使用【】或《》标记\n"
|
||||||
|
|
||||||
batch_prompt = prompt_service.get_complete_outline_prompt(
|
batch_prompt = prompt_service.get_complete_outline_prompt(
|
||||||
title=project.title,
|
title=project.title,
|
||||||
@@ -796,7 +807,7 @@ async def outline_generator(
|
|||||||
|
|
||||||
# 流式生成
|
# 流式生成
|
||||||
accumulated_text = ""
|
accumulated_text = ""
|
||||||
async for chunk in ai_service.generate_text_stream(
|
async for chunk in user_ai_service.generate_text_stream(
|
||||||
prompt=batch_prompt,
|
prompt=batch_prompt,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
model=model
|
model=model
|
||||||
@@ -806,12 +817,14 @@ async def outline_generator(
|
|||||||
|
|
||||||
# 解析结果
|
# 解析结果
|
||||||
cleaned_text = accumulated_text.strip()
|
cleaned_text = accumulated_text.strip()
|
||||||
|
|
||||||
|
# 移除markdown代码块标记
|
||||||
if cleaned_text.startswith('```json'):
|
if cleaned_text.startswith('```json'):
|
||||||
cleaned_text = cleaned_text[7:]
|
cleaned_text = cleaned_text[7:].lstrip('\n\r')
|
||||||
if cleaned_text.startswith('```'):
|
elif cleaned_text.startswith('```'):
|
||||||
cleaned_text = cleaned_text[3:]
|
cleaned_text = cleaned_text[3:].lstrip('\n\r')
|
||||||
if cleaned_text.endswith('```'):
|
if cleaned_text.endswith('```'):
|
||||||
cleaned_text = cleaned_text[:-3]
|
cleaned_text = cleaned_text[:-3].rstrip('\n\r')
|
||||||
cleaned_text = cleaned_text.strip()
|
cleaned_text = cleaned_text.strip()
|
||||||
|
|
||||||
batch_outline_data = json.loads(cleaned_text)
|
batch_outline_data = json.loads(cleaned_text)
|
||||||
@@ -839,7 +852,7 @@ async def outline_generator(
|
|||||||
logger.info(f"批次{batch_idx+1}成功生成{len(batch_outline_data)}章大纲")
|
logger.info(f"批次{batch_idx+1}成功生成{len(batch_outline_data)}章大纲")
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.error(f"批次{batch_idx+1}解析失败(尝试{retry_count+1}/{MAX_RETRIES}): {e}")
|
logger.error(f"大纲生成批次{batch_idx+1} JSON解析失败(尝试{retry_count+1}/{MAX_RETRIES}): {e}")
|
||||||
retry_count += 1
|
retry_count += 1
|
||||||
if retry_count < MAX_RETRIES:
|
if retry_count < MAX_RETRIES:
|
||||||
yield await SSEResponse.send_progress(
|
yield await SSEResponse.send_progress(
|
||||||
@@ -945,12 +958,13 @@ async def outline_generator(
|
|||||||
@router.post("/outline", summary="流式生成完整大纲")
|
@router.post("/outline", summary="流式生成完整大纲")
|
||||||
async def generate_outline_stream(
|
async def generate_outline_stream(
|
||||||
data: Dict[str, Any],
|
data: Dict[str, Any],
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db),
|
||||||
|
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
使用SSE流式生成完整大纲,避免超时
|
使用SSE流式生成完整大纲,避免超时
|
||||||
"""
|
"""
|
||||||
return create_sse_response(outline_generator(data, db))
|
return create_sse_response(outline_generator(data, db, user_ai_service))
|
||||||
|
|
||||||
|
|
||||||
async def update_world_building_generator(
|
async def update_world_building_generator(
|
||||||
@@ -1037,7 +1051,8 @@ async def update_world_building_stream(
|
|||||||
async def regenerate_world_building_generator(
|
async def regenerate_world_building_generator(
|
||||||
project_id: str,
|
project_id: str,
|
||||||
data: Dict[str, Any],
|
data: Dict[str, Any],
|
||||||
db: AsyncSession
|
db: AsyncSession,
|
||||||
|
user_ai_service: AIService
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""重新生成世界观流式生成器"""
|
"""重新生成世界观流式生成器"""
|
||||||
db_committed = False
|
db_committed = False
|
||||||
@@ -1070,7 +1085,7 @@ async def regenerate_world_building_generator(
|
|||||||
accumulated_text = ""
|
accumulated_text = ""
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
|
|
||||||
async for chunk in ai_service.generate_text_stream(
|
async for chunk in user_ai_service.generate_text_stream(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
model=model
|
model=model
|
||||||
@@ -1096,19 +1111,21 @@ async def regenerate_world_building_generator(
|
|||||||
world_data = {}
|
world_data = {}
|
||||||
try:
|
try:
|
||||||
cleaned_text = accumulated_text.strip()
|
cleaned_text = accumulated_text.strip()
|
||||||
|
# 移除markdown代码块标记
|
||||||
if cleaned_text.startswith('```json'):
|
if cleaned_text.startswith('```json'):
|
||||||
cleaned_text = cleaned_text[7:]
|
cleaned_text = cleaned_text[7:].lstrip('\n\r')
|
||||||
if cleaned_text.startswith('```'):
|
elif cleaned_text.startswith('```'):
|
||||||
cleaned_text = cleaned_text[3:]
|
cleaned_text = cleaned_text[3:].lstrip('\n\r')
|
||||||
if cleaned_text.endswith('```'):
|
if cleaned_text.endswith('```'):
|
||||||
cleaned_text = cleaned_text[:-3]
|
cleaned_text = cleaned_text[:-3].rstrip('\n\r')
|
||||||
cleaned_text = cleaned_text.strip()
|
cleaned_text = cleaned_text.strip()
|
||||||
|
|
||||||
world_data = json.loads(cleaned_text)
|
world_data = json.loads(cleaned_text)
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.error(f"AI返回非JSON格式: {e}")
|
logger.error(f"AI返回非JSON格式: {e}")
|
||||||
|
logger.info(world_data)
|
||||||
world_data = {
|
world_data = {
|
||||||
"time_period": accumulated_text[:300] if len(accumulated_text) > 300 else accumulated_text,
|
"time_period": "AI返回格式错误,请重试",
|
||||||
"location": "AI返回格式错误,请重试",
|
"location": "AI返回格式错误,请重试",
|
||||||
"atmosphere": "AI返回格式错误,请重试",
|
"atmosphere": "AI返回格式错误,请重试",
|
||||||
"rules": "AI返回格式错误,请重试"
|
"rules": "AI返回格式错误,请重试"
|
||||||
@@ -1155,7 +1172,8 @@ async def regenerate_world_building_generator(
|
|||||||
async def regenerate_world_building_stream(
|
async def regenerate_world_building_stream(
|
||||||
project_id: str,
|
project_id: str,
|
||||||
data: Dict[str, Any],
|
data: Dict[str, Any],
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db),
|
||||||
|
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
使用SSE流式重新生成项目的世界观
|
使用SSE流式重新生成项目的世界观
|
||||||
@@ -1165,7 +1183,7 @@ async def regenerate_world_building_stream(
|
|||||||
"model": "模型名称(可选)"
|
"model": "模型名称(可选)"
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
return create_sse_response(regenerate_world_building_generator(project_id, data, db))
|
return create_sse_response(regenerate_world_building_generator(project_id, data, db, user_ai_service))
|
||||||
|
|
||||||
|
|
||||||
async def cleanup_wizard_data_generator(
|
async def cleanup_wizard_data_generator(
|
||||||
|
|||||||
+2
-1
@@ -114,11 +114,12 @@ async def db_session_stats():
|
|||||||
from app.api import (
|
from app.api import (
|
||||||
projects, outlines, characters, chapters,
|
projects, outlines, characters, chapters,
|
||||||
wizard_stream, relationships, organizations,
|
wizard_stream, relationships, organizations,
|
||||||
auth, users
|
auth, users, settings
|
||||||
)
|
)
|
||||||
|
|
||||||
app.include_router(auth.router, prefix="/api")
|
app.include_router(auth.router, prefix="/api")
|
||||||
app.include_router(users.router, prefix="/api")
|
app.include_router(users.router, prefix="/api")
|
||||||
|
app.include_router(settings.router, prefix="/api")
|
||||||
|
|
||||||
app.include_router(projects.router, prefix="/api")
|
app.include_router(projects.router, prefix="/api")
|
||||||
app.include_router(wizard_stream.router, prefix="/api")
|
app.include_router(wizard_stream.router, prefix="/api")
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
"""设置数据模型"""
|
"""设置数据模型"""
|
||||||
from sqlalchemy import Column, String, Text, Float, Integer, DateTime
|
from sqlalchemy import Column, String, Text, Float, Integer, DateTime, Index
|
||||||
from sqlalchemy.sql import func
|
from sqlalchemy.sql import func
|
||||||
from app.database import Base
|
from app.database import Base
|
||||||
import uuid
|
import uuid
|
||||||
@@ -10,6 +10,7 @@ class Settings(Base):
|
|||||||
__tablename__ = "settings"
|
__tablename__ = "settings"
|
||||||
|
|
||||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||||
|
user_id = Column(String(50), nullable=False, unique=True, index=True, comment="用户ID")
|
||||||
api_provider = Column(String(50), default="openai", comment="API提供商")
|
api_provider = Column(String(50), default="openai", comment="API提供商")
|
||||||
api_key = Column(String(500), comment="API密钥")
|
api_key = Column(String(500), comment="API密钥")
|
||||||
api_base_url = Column(String(500), comment="自定义API地址")
|
api_base_url = Column(String(500), comment="自定义API地址")
|
||||||
@@ -20,5 +21,9 @@ class Settings(Base):
|
|||||||
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
||||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index('idx_user_id', 'user_id'),
|
||||||
|
)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<Settings(id={self.id}, api_provider={self.api_provider})>"
|
return f"<Settings(id={self.id}, user_id={self.user_id}, api_provider={self.api_provider})>"
|
||||||
@@ -0,0 +1,37 @@
|
|||||||
|
"""设置相关的Pydantic模型"""
|
||||||
|
from pydantic import BaseModel, Field, ConfigDict
|
||||||
|
from typing import Optional
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class SettingsBase(BaseModel):
|
||||||
|
"""设置基础模型"""
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
api_provider: Optional[str] = Field(default="openai", description="API提供商")
|
||||||
|
api_key: Optional[str] = Field(default=None, description="API密钥")
|
||||||
|
api_base_url: Optional[str] = Field(default=None, description="自定义API地址")
|
||||||
|
model_name: Optional[str] = Field(default="gpt-4", description="模型名称")
|
||||||
|
temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0, description="温度参数")
|
||||||
|
max_tokens: Optional[int] = Field(default=2000, ge=1, le=32000, description="最大token数")
|
||||||
|
preferences: Optional[str] = Field(default=None, description="其他偏好设置(JSON)")
|
||||||
|
|
||||||
|
|
||||||
|
class SettingsCreate(SettingsBase):
|
||||||
|
"""创建设置请求模型"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SettingsUpdate(SettingsBase):
|
||||||
|
"""更新设置请求模型"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SettingsResponse(SettingsBase):
|
||||||
|
"""设置响应模型"""
|
||||||
|
model_config = ConfigDict(from_attributes=True, protected_namespaces=())
|
||||||
|
|
||||||
|
id: str
|
||||||
|
user_id: str
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
from typing import Optional, AsyncGenerator, List, Dict, Any
|
from typing import Optional, AsyncGenerator, List, Dict, Any
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
from anthropic import AsyncAnthropic
|
from anthropic import AsyncAnthropic
|
||||||
from app.config import settings
|
from app.config import settings as app_settings
|
||||||
from app.logger import get_logger
|
from app.logger import get_logger
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
@@ -10,12 +10,37 @@ logger = get_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class AIService:
|
class AIService:
|
||||||
"""AI服务统一接口"""
|
"""AI服务统一接口 - 支持从用户设置或全局配置初始化"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_provider: Optional[str] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
api_base_url: Optional[str] = None,
|
||||||
|
default_model: Optional[str] = None,
|
||||||
|
default_temperature: Optional[float] = None,
|
||||||
|
default_max_tokens: Optional[int] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化AI客户端(优化并发性能)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_provider: API提供商 (openai/anthropic),为None时使用全局配置
|
||||||
|
api_key: API密钥,为None时使用全局配置
|
||||||
|
api_base_url: API基础URL,为None时使用全局配置
|
||||||
|
default_model: 默认模型,为None时使用全局配置
|
||||||
|
default_temperature: 默认温度,为None时使用全局配置
|
||||||
|
default_max_tokens: 默认最大tokens,为None时使用全局配置
|
||||||
|
"""
|
||||||
|
# 保存用户设置或使用全局配置
|
||||||
|
self.api_provider = api_provider or app_settings.default_ai_provider
|
||||||
|
self.default_model = default_model or app_settings.default_model
|
||||||
|
self.default_temperature = default_temperature or app_settings.default_temperature
|
||||||
|
self.default_max_tokens = default_max_tokens or app_settings.default_max_tokens
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
"""初始化AI客户端(优化并发性能)"""
|
|
||||||
# 初始化OpenAI客户端
|
# 初始化OpenAI客户端
|
||||||
if settings.openai_api_key:
|
openai_key = api_key if api_provider == "openai" else app_settings.openai_api_key
|
||||||
|
if openai_key:
|
||||||
# 创建自定义的httpx客户端来避免proxies参数问题
|
# 创建自定义的httpx客户端来避免proxies参数问题
|
||||||
try:
|
try:
|
||||||
# 配置连接池限制,支持高并发
|
# 配置连接池限制,支持高并发
|
||||||
@@ -43,12 +68,14 @@ class AIService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
client_kwargs = {
|
client_kwargs = {
|
||||||
"api_key": settings.openai_api_key,
|
"api_key": openai_key,
|
||||||
"http_client": http_client
|
"http_client": http_client
|
||||||
}
|
}
|
||||||
|
|
||||||
if settings.openai_base_url:
|
# 优先使用用户提供的base_url,否则使用全局配置
|
||||||
client_kwargs["base_url"] = settings.openai_base_url
|
base_url = api_base_url if api_provider == "openai" else app_settings.openai_base_url
|
||||||
|
if base_url:
|
||||||
|
client_kwargs["base_url"] = base_url
|
||||||
|
|
||||||
self.openai_client = AsyncOpenAI(**client_kwargs)
|
self.openai_client = AsyncOpenAI(**client_kwargs)
|
||||||
logger.info("✅ OpenAI客户端初始化成功")
|
logger.info("✅ OpenAI客户端初始化成功")
|
||||||
@@ -62,7 +89,8 @@ class AIService:
|
|||||||
logger.warning("OpenAI API key未配置")
|
logger.warning("OpenAI API key未配置")
|
||||||
|
|
||||||
# 初始化Anthropic客户端
|
# 初始化Anthropic客户端
|
||||||
if settings.anthropic_api_key:
|
anthropic_key = api_key if api_provider == "anthropic" else app_settings.anthropic_api_key
|
||||||
|
if anthropic_key:
|
||||||
try:
|
try:
|
||||||
# 为Anthropic设置相同的超时和连接池配置
|
# 为Anthropic设置相同的超时和连接池配置
|
||||||
limits = httpx.Limits(
|
limits = httpx.Limits(
|
||||||
@@ -82,12 +110,14 @@ class AIService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
client_kwargs = {
|
client_kwargs = {
|
||||||
"api_key": settings.anthropic_api_key,
|
"api_key": anthropic_key,
|
||||||
"http_client": http_client
|
"http_client": http_client
|
||||||
}
|
}
|
||||||
|
|
||||||
if settings.anthropic_base_url:
|
# 优先使用用户提供的base_url,否则使用全局配置
|
||||||
client_kwargs["base_url"] = settings.anthropic_base_url
|
base_url = api_base_url if api_provider == "anthropic" else app_settings.anthropic_base_url
|
||||||
|
if base_url:
|
||||||
|
client_kwargs["base_url"] = base_url
|
||||||
|
|
||||||
self.anthropic_client = AsyncAnthropic(**client_kwargs)
|
self.anthropic_client = AsyncAnthropic(**client_kwargs)
|
||||||
logger.info("✅ Anthropic客户端初始化成功")
|
logger.info("✅ Anthropic客户端初始化成功")
|
||||||
@@ -123,10 +153,10 @@ class AIService:
|
|||||||
Returns:
|
Returns:
|
||||||
生成的文本
|
生成的文本
|
||||||
"""
|
"""
|
||||||
provider = provider or settings.default_ai_provider
|
provider = provider or self.api_provider
|
||||||
model = model or settings.default_model
|
model = model or self.default_model
|
||||||
temperature = temperature or settings.default_temperature
|
temperature = temperature or self.default_temperature
|
||||||
max_tokens = max_tokens or settings.default_max_tokens
|
max_tokens = max_tokens or self.default_max_tokens
|
||||||
|
|
||||||
if provider == "openai":
|
if provider == "openai":
|
||||||
return await self._generate_openai(
|
return await self._generate_openai(
|
||||||
@@ -162,10 +192,10 @@ class AIService:
|
|||||||
Yields:
|
Yields:
|
||||||
生成的文本片段
|
生成的文本片段
|
||||||
"""
|
"""
|
||||||
provider = provider or settings.default_ai_provider
|
provider = provider or self.api_provider
|
||||||
model = model or settings.default_model
|
model = model or self.default_model
|
||||||
temperature = temperature or settings.default_temperature
|
temperature = temperature or self.default_temperature
|
||||||
max_tokens = max_tokens or settings.default_max_tokens
|
max_tokens = max_tokens or self.default_max_tokens
|
||||||
|
|
||||||
if provider == "openai":
|
if provider == "openai":
|
||||||
async for chunk in self._generate_openai_stream(
|
async for chunk in self._generate_openai_stream(
|
||||||
@@ -359,5 +389,37 @@ class AIService:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
# 创建全局AI服务实例
|
# 创建全局AI服务实例(使用环境变量配置,用于向后兼容)
|
||||||
ai_service = AIService()
|
ai_service = AIService()
|
||||||
|
|
||||||
|
|
||||||
|
def create_user_ai_service(
|
||||||
|
api_provider: str,
|
||||||
|
api_key: str,
|
||||||
|
api_base_url: str,
|
||||||
|
model_name: str,
|
||||||
|
temperature: float,
|
||||||
|
max_tokens: int
|
||||||
|
) -> AIService:
|
||||||
|
"""
|
||||||
|
根据用户设置创建AI服务实例
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_provider: API提供商
|
||||||
|
api_key: API密钥
|
||||||
|
api_base_url: API基础URL
|
||||||
|
model_name: 模型名称
|
||||||
|
temperature: 温度参数
|
||||||
|
max_tokens: 最大tokens
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AIService实例
|
||||||
|
"""
|
||||||
|
return AIService(
|
||||||
|
api_provider=api_provider,
|
||||||
|
api_key=api_key,
|
||||||
|
api_base_url=api_base_url,
|
||||||
|
default_model=model_name,
|
||||||
|
default_temperature=temperature,
|
||||||
|
default_max_tokens=max_tokens
|
||||||
|
)
|
||||||
@@ -26,7 +26,14 @@ class PromptService:
|
|||||||
- 为故事发展提供支撑
|
- 为故事发展提供支撑
|
||||||
- 具有独特性和吸引力
|
- 具有独特性和吸引力
|
||||||
|
|
||||||
**重要:你必须只返回纯JSON格式,不要包含任何markdown标记、代码块标记或其他说明文字。**
|
**重要格式要求:**
|
||||||
|
1. 只返回纯JSON格式,不要包含任何markdown标记、代码块标记或其他说明文字
|
||||||
|
2. 不要在JSON字符串值中使用中文引号(""''),请使用英文引号或直接省略引号
|
||||||
|
3. 专有名词和强调内容可以使用【】或《》标记,不要用引号
|
||||||
|
|
||||||
|
**正确示例**:
|
||||||
|
- ✅ "距离【大灾变】爆发" 或 "距离大灾变爆发"
|
||||||
|
- ❌ "距离"大灾变"爆发" (会导致JSON解析失败)
|
||||||
|
|
||||||
请严格按照以下JSON格式返回(每个字段为200-300字的文本描述):
|
请严格按照以下JSON格式返回(每个字段为200-300字的文本描述):
|
||||||
{{
|
{{
|
||||||
@@ -36,7 +43,10 @@ class PromptService:
|
|||||||
"rules": "世界规则的详细描述,包括运行法则、特殊设定、社会规则、权力结构"
|
"rules": "世界规则的详细描述,包括运行法则、特殊设定、社会规则、权力结构"
|
||||||
}}
|
}}
|
||||||
|
|
||||||
再次强调:只返回纯JSON对象,不要有```json```这样的标记,不要有任何额外的文字说明。"""
|
再次强调:
|
||||||
|
1. 只返回纯JSON对象,不要有```json```这样的标记
|
||||||
|
2. 文本中不要使用中文引号(""),使用【】或《》代替
|
||||||
|
3. 不要有任何额外的文字说明"""
|
||||||
|
|
||||||
# 批量角色生成提示词
|
# 批量角色生成提示词
|
||||||
CHARACTERS_BATCH_GENERATION = """你是一位专业的角色设定师。请根据以下世界观和要求,生成{count}个立体丰满的角色和组织:
|
CHARACTERS_BATCH_GENERATION = """你是一位专业的角色设定师。请根据以下世界观和要求,生成{count}个立体丰满的角色和组织:
|
||||||
@@ -67,7 +77,10 @@ class PromptService:
|
|||||||
- 组织要有存在的合理性
|
- 组织要有存在的合理性
|
||||||
- 所有实体要为故事服务
|
- 所有实体要为故事服务
|
||||||
|
|
||||||
**重要:你必须只返回纯JSON数组格式,不要包含任何markdown标记、代码块标记或其他说明文字。**
|
**重要格式要求:**
|
||||||
|
1. 只返回纯JSON数组格式,不要包含任何markdown标记、代码块标记或其他说明文字
|
||||||
|
2. 不要在JSON字符串值中使用中文引号(""''),请使用英文引号或【】《》标记
|
||||||
|
3. 专有名词和强调内容使用【】或《》,不要用引号
|
||||||
|
|
||||||
请严格按照以下JSON数组格式返回(每个角色为数组中的一个对象):
|
请严格按照以下JSON数组格式返回(每个角色为数组中的一个对象):
|
||||||
[
|
[
|
||||||
@@ -134,7 +147,8 @@ class PromptService:
|
|||||||
再次强调:
|
再次强调:
|
||||||
1. 只返回纯JSON数组,不要有```json```这样的标记
|
1. 只返回纯JSON数组,不要有```json```这样的标记
|
||||||
2. 数组中必须精确包含{count}个对象
|
2. 数组中必须精确包含{count}个对象
|
||||||
3. 不要引用任何本批次中不存在的角色或组织名称"""
|
3. 不要引用任何本批次中不存在的角色或组织名称
|
||||||
|
4. 文本描述中不要使用中文引号(""),改用【】或《》"""
|
||||||
|
|
||||||
# 完整大纲生成提示词
|
# 完整大纲生成提示词
|
||||||
COMPLETE_OUTLINE_GENERATION = """你是一位经验丰富的小说作家和编剧。请根据以下信息生成完整的{chapter_count}章小说大纲:
|
COMPLETE_OUTLINE_GENERATION = """你是一位经验丰富的小说作家和编剧。请根据以下信息生成完整的{chapter_count}章小说大纲:
|
||||||
@@ -166,7 +180,10 @@ class PromptService:
|
|||||||
- 节奏把控:有张有弛
|
- 节奏把控:有张有弛
|
||||||
- 视角统一:采用{narrative_perspective}视角叙事
|
- 视角统一:采用{narrative_perspective}视角叙事
|
||||||
|
|
||||||
**重要:你必须只返回纯JSON数组格式,不要包含任何markdown标记、代码块标记或其他说明文字。**
|
**重要格式要求:**
|
||||||
|
1. 只返回纯JSON数组格式,不要包含任何markdown标记、代码块标记或其他说明文字
|
||||||
|
2. 不要在JSON字符串值中使用中文引号(""''),请使用【】或《》标记
|
||||||
|
3. 专有名词、书名、事件名使用【】或《》
|
||||||
|
|
||||||
请严格按照以下JSON数组格式返回(共{chapter_count}个章节对象):
|
请严格按照以下JSON数组格式返回(共{chapter_count}个章节对象):
|
||||||
[
|
[
|
||||||
@@ -192,7 +209,10 @@ class PromptService:
|
|||||||
}}
|
}}
|
||||||
]
|
]
|
||||||
|
|
||||||
再次强调:只返回纯JSON数组,不要有```json```这样的标记,不要有任何额外的文字说明。数组中要包含{chapter_count}个章节对象。"""
|
再次强调:
|
||||||
|
1. 只返回纯JSON数组,不要有```json```这样的标记
|
||||||
|
2. 数组中要包含{chapter_count}个章节对象
|
||||||
|
3. 文本中不要使用中文引号(""),改用【】或《》"""
|
||||||
|
|
||||||
# 大纲续写提示词
|
# 大纲续写提示词
|
||||||
OUTLINE_CONTINUE_GENERATION = """你是一位经验丰富的小说作家和编剧。请基于以下信息续写小说大纲:
|
OUTLINE_CONTINUE_GENERATION = """你是一位经验丰富的小说作家和编剧。请基于以下信息续写小说大纲:
|
||||||
@@ -232,7 +252,10 @@ class PromptService:
|
|||||||
- 保持与已有章节相同的风格和详细程度
|
- 保持与已有章节相同的风格和详细程度
|
||||||
- 推进角色成长和情节发展
|
- 推进角色成长和情节发展
|
||||||
|
|
||||||
**重要:你必须只返回纯JSON数组格式,不要包含任何markdown标记、代码块标记或其他说明文字。**
|
**重要格式要求:**
|
||||||
|
1. 只返回纯JSON数组格式,不要包含任何markdown标记、代码块标记或其他说明文字
|
||||||
|
2. 不要在JSON字符串值中使用中文引号(""''),请使用【】或《》
|
||||||
|
3. 文本描述中的专有名词使用【】标记
|
||||||
|
|
||||||
请严格按照以下JSON数组格式返回(共{chapter_count}个章节对象):
|
请严格按照以下JSON数组格式返回(共{chapter_count}个章节对象):
|
||||||
[
|
[
|
||||||
@@ -262,7 +285,8 @@ class PromptService:
|
|||||||
1. 只返回纯JSON数组,不要有```json```这样的标记
|
1. 只返回纯JSON数组,不要有```json```这样的标记
|
||||||
2. 数组中要包含{chapter_count}个章节对象
|
2. 数组中要包含{chapter_count}个章节对象
|
||||||
3. 每个summary必须是100-200字的详细描述
|
3. 每个summary必须是100-200字的详细描述
|
||||||
4. 确保字段结构与已有章节完全一致"""
|
4. 确保字段结构与已有章节完全一致
|
||||||
|
5. 文本中不要使用中文引号(""),改用【】或《》"""
|
||||||
|
|
||||||
# AI去味提示词(核心特色功能)
|
# AI去味提示词(核心特色功能)
|
||||||
AI_DENOISING = """你是一位追求自然写作风格的编辑。你的任务是将AI生成的文本改写得更像人类作家的手笔。
|
AI_DENOISING = """你是一位追求自然写作风格的编辑。你的任务是将AI生成的文本改写得更像人类作家的手笔。
|
||||||
@@ -431,7 +455,10 @@ class PromptService:
|
|||||||
4. 情节的递进和冲突升级
|
4. 情节的递进和冲突升级
|
||||||
5. 角色的成长弧线
|
5. 角色的成长弧线
|
||||||
|
|
||||||
**重要:你必须只返回纯JSON格式,不要包含任何markdown标记、代码块标记或其他说明文字。**
|
**重要格式要求:**
|
||||||
|
1. 只返回纯JSON格式,不要包含任何markdown标记、代码块标记或其他说明文字
|
||||||
|
2. 不要在JSON字符串值中使用中文引号(""''),改用【】或《》
|
||||||
|
3. 专有名词和强调内容使用【】标记
|
||||||
|
|
||||||
请严格按照以下JSON格式返回:
|
请严格按照以下JSON格式返回:
|
||||||
{{
|
{{
|
||||||
@@ -444,7 +471,10 @@ class PromptService:
|
|||||||
]
|
]
|
||||||
}}
|
}}
|
||||||
|
|
||||||
再次强调:只返回纯JSON对象,不要有```json```这样的标记,不要有任何额外的文字说明。"""
|
再次强调:
|
||||||
|
1. 只返回纯JSON对象,不要有```json```这样的标记
|
||||||
|
2. 文本中不要使用中文引号(""),改用【】或《》
|
||||||
|
3. 不要有任何额外的文字说明"""
|
||||||
|
|
||||||
# 单个角色生成提示词
|
# 单个角色生成提示词
|
||||||
SINGLE_CHARACTER_GENERATION = """你是一位专业的角色设定师。请根据以下信息创建一个立体饱满的小说角色。
|
SINGLE_CHARACTER_GENERATION = """你是一位专业的角色设定师。请根据以下信息创建一个立体饱满的小说角色。
|
||||||
@@ -487,7 +517,10 @@ class PromptService:
|
|||||||
- 特殊技能或知识
|
- 特殊技能或知识
|
||||||
- 符合世界观设定
|
- 符合世界观设定
|
||||||
|
|
||||||
**你必须只返回纯JSON格式,不要包含任何markdown标记、代码块标记或其他说明文字。**
|
**重要格式要求:**
|
||||||
|
1. 只返回纯JSON格式,不要包含任何markdown标记、代码块标记或其他说明文字
|
||||||
|
2. 不要在JSON字符串值中使用中文引号(""''),改用【】或《》
|
||||||
|
3. 文本描述中的专有名词使用【】标记
|
||||||
|
|
||||||
请严格按照以下JSON格式返回:
|
请严格按照以下JSON格式返回:
|
||||||
{{
|
{{
|
||||||
@@ -543,7 +576,10 @@ class PromptService:
|
|||||||
- 配角要有独特性,不能是工具人
|
- 配角要有独特性,不能是工具人
|
||||||
- 所有设定要为故事服务
|
- 所有设定要为故事服务
|
||||||
|
|
||||||
再次强调:只返回纯JSON对象,不要有```json```这样的标记,不要有任何额外的文字说明。"""
|
再次强调:
|
||||||
|
1. 只返回纯JSON对象,不要有```json```这样的标记
|
||||||
|
2. 文本中不要使用中文引号(""),改用【】或《》
|
||||||
|
3. 不要有任何额外的文字说明"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def format_prompt(template: str, **kwargs) -> str:
|
def format_prompt(template: str, **kwargs) -> str:
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import Characters from './pages/Characters';
|
|||||||
import Relationships from './pages/Relationships';
|
import Relationships from './pages/Relationships';
|
||||||
import Organizations from './pages/Organizations';
|
import Organizations from './pages/Organizations';
|
||||||
import Chapters from './pages/Chapters';
|
import Chapters from './pages/Chapters';
|
||||||
|
import Settings from './pages/Settings';
|
||||||
// import Polish from './pages/Polish';
|
// import Polish from './pages/Polish';
|
||||||
import Login from './pages/Login';
|
import Login from './pages/Login';
|
||||||
import AuthCallback from './pages/AuthCallback';
|
import AuthCallback from './pages/AuthCallback';
|
||||||
@@ -31,6 +32,7 @@ function App() {
|
|||||||
|
|
||||||
<Route path="/" element={<ProtectedRoute><ProjectList /></ProtectedRoute>} />
|
<Route path="/" element={<ProtectedRoute><ProjectList /></ProtectedRoute>} />
|
||||||
<Route path="/wizard" element={<ProtectedRoute><ProjectWizardNew /></ProtectedRoute>} />
|
<Route path="/wizard" element={<ProtectedRoute><ProjectWizardNew /></ProtectedRoute>} />
|
||||||
|
<Route path="/settings" element={<ProtectedRoute><Settings /></ProtectedRoute>} />
|
||||||
<Route path="/project/:projectId" element={<ProtectedRoute><ProjectDetail /></ProtectedRoute>}>
|
<Route path="/project/:projectId" element={<ProtectedRoute><ProjectDetail /></ProtectedRoute>}>
|
||||||
<Route index element={<Navigate to="world-setting" replace />} />
|
<Route index element={<Navigate to="world-setting" replace />} />
|
||||||
<Route path="world-setting" element={<WorldSetting />} />
|
<Route path="world-setting" element={<WorldSetting />} />
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import { useEffect } from 'react';
|
import { useEffect } from 'react';
|
||||||
import { useNavigate } from 'react-router-dom';
|
import { useNavigate } from 'react-router-dom';
|
||||||
import { Card, Button, Empty, Modal, message, Spin, Row, Col, Statistic, Space, Tag, Progress, Typography, Tooltip, Badge } from 'antd';
|
import { Card, Button, Empty, Modal, message, Spin, Row, Col, Statistic, Space, Tag, Progress, Typography, Tooltip, Badge } from 'antd';
|
||||||
import { EditOutlined, DeleteOutlined, BookOutlined, RocketOutlined, CalendarOutlined, FileTextOutlined, TrophyOutlined, FireOutlined } from '@ant-design/icons';
|
import { EditOutlined, DeleteOutlined, BookOutlined, RocketOutlined, CalendarOutlined, FileTextOutlined, TrophyOutlined, FireOutlined, SettingOutlined } from '@ant-design/icons';
|
||||||
import { useStore } from '../store';
|
import { useStore } from '../store';
|
||||||
import { useProjectSync } from '../store/hooks';
|
import { useProjectSync } from '../store/hooks';
|
||||||
import type { ReactNode } from 'react';
|
import type { ReactNode } from 'react';
|
||||||
@@ -161,11 +161,36 @@ export default function ProjectList() {
|
|||||||
style={{
|
style={{
|
||||||
borderRadius: 8,
|
borderRadius: 8,
|
||||||
background: 'linear-gradient(135deg, #667eea 0%, #764ba2 100%)',
|
background: 'linear-gradient(135deg, #667eea 0%, #764ba2 100%)',
|
||||||
border: 'none'
|
border: 'none',
|
||||||
|
boxShadow: '0 2px 8px rgba(102, 126, 234, 0.4)'
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
向导创建
|
向导创建
|
||||||
</Button>
|
</Button>
|
||||||
|
<Button
|
||||||
|
type="default"
|
||||||
|
size={window.innerWidth <= 768 ? 'middle' : 'large'}
|
||||||
|
icon={<SettingOutlined />}
|
||||||
|
onClick={() => navigate('/settings')}
|
||||||
|
style={{
|
||||||
|
borderRadius: 8,
|
||||||
|
borderColor: '#d9d9d9',
|
||||||
|
boxShadow: '0 2px 8px rgba(0, 0, 0, 0.08)',
|
||||||
|
transition: 'all 0.3s ease'
|
||||||
|
}}
|
||||||
|
onMouseEnter={(e) => {
|
||||||
|
e.currentTarget.style.borderColor = '#667eea';
|
||||||
|
e.currentTarget.style.color = '#667eea';
|
||||||
|
e.currentTarget.style.boxShadow = '0 2px 12px rgba(102, 126, 234, 0.3)';
|
||||||
|
}}
|
||||||
|
onMouseLeave={(e) => {
|
||||||
|
e.currentTarget.style.borderColor = '#d9d9d9';
|
||||||
|
e.currentTarget.style.color = 'rgba(0, 0, 0, 0.88)';
|
||||||
|
e.currentTarget.style.boxShadow = '0 2px 8px rgba(0, 0, 0, 0.08)';
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
API设置
|
||||||
|
</Button>
|
||||||
<UserMenu />
|
<UserMenu />
|
||||||
</Col>
|
</Col>
|
||||||
</Row>
|
</Row>
|
||||||
|
|||||||
@@ -0,0 +1,507 @@
|
|||||||
|
import { useState, useEffect } from 'react';
|
||||||
|
import { useNavigate } from 'react-router-dom';
|
||||||
|
import { Card, Form, Input, Button, Select, Slider, InputNumber, message, Space, Typography, Spin, Modal, Tooltip, Alert } from 'antd';
|
||||||
|
import { SettingOutlined, SaveOutlined, DeleteOutlined, ReloadOutlined, ArrowLeftOutlined, InfoCircleOutlined} from '@ant-design/icons';
|
||||||
|
import { settingsApi } from '../services/api';
|
||||||
|
import type { SettingsUpdate } from '../types';
|
||||||
|
|
||||||
|
const { Title, Paragraph } = Typography;
|
||||||
|
const { Option } = Select;
|
||||||
|
|
||||||
|
export default function SettingsPage() {
|
||||||
|
const navigate = useNavigate();
|
||||||
|
const [form] = Form.useForm();
|
||||||
|
const [loading, setLoading] = useState(false);
|
||||||
|
const [initialLoading, setInitialLoading] = useState(true);
|
||||||
|
const [hasSettings, setHasSettings] = useState(false);
|
||||||
|
const [isDefaultSettings, setIsDefaultSettings] = useState(false);
|
||||||
|
const [modelOptions, setModelOptions] = useState<Array<{ value: string; label: string; description: string }>>([]);
|
||||||
|
const [fetchingModels, setFetchingModels] = useState(false);
|
||||||
|
const [modelsFetched, setModelsFetched] = useState(false);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
loadSettings();
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const loadSettings = async () => {
|
||||||
|
setInitialLoading(true);
|
||||||
|
try {
|
||||||
|
const settings = await settingsApi.getSettings();
|
||||||
|
form.setFieldsValue(settings);
|
||||||
|
|
||||||
|
// 判断是否为默认设置(id='0'表示来自.env的默认配置)
|
||||||
|
if (settings.id === '0' || !settings.id) {
|
||||||
|
setIsDefaultSettings(true);
|
||||||
|
setHasSettings(false);
|
||||||
|
} else {
|
||||||
|
setIsDefaultSettings(false);
|
||||||
|
setHasSettings(true);
|
||||||
|
}
|
||||||
|
} catch (error: any) {
|
||||||
|
// 如果404表示还没有设置,使用默认值
|
||||||
|
if (error?.response?.status === 404) {
|
||||||
|
setHasSettings(false);
|
||||||
|
setIsDefaultSettings(true);
|
||||||
|
form.setFieldsValue({
|
||||||
|
api_provider: 'openai',
|
||||||
|
api_base_url: 'https://api.openai.com/v1',
|
||||||
|
model_name: 'gpt-4',
|
||||||
|
temperature: 0.7,
|
||||||
|
max_tokens: 2000,
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
message.error('加载设置失败');
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
setInitialLoading(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleSave = async (values: SettingsUpdate) => {
|
||||||
|
setLoading(true);
|
||||||
|
try {
|
||||||
|
await settingsApi.saveSettings(values);
|
||||||
|
message.success('设置已保存');
|
||||||
|
setHasSettings(true);
|
||||||
|
setIsDefaultSettings(false);
|
||||||
|
} catch (error) {
|
||||||
|
message.error('保存设置失败');
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleReset = () => {
|
||||||
|
Modal.confirm({
|
||||||
|
title: '重置设置',
|
||||||
|
content: '确定要重置为默认值吗?',
|
||||||
|
okText: '确定',
|
||||||
|
cancelText: '取消',
|
||||||
|
onOk: () => {
|
||||||
|
form.setFieldsValue({
|
||||||
|
api_provider: 'openai',
|
||||||
|
api_key: '',
|
||||||
|
api_base_url: 'https://api.openai.com/v1',
|
||||||
|
model_name: 'gpt-4',
|
||||||
|
temperature: 0.7,
|
||||||
|
max_tokens: 2000,
|
||||||
|
});
|
||||||
|
message.info('已重置为默认值,请点击保存');
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleDelete = () => {
|
||||||
|
Modal.confirm({
|
||||||
|
title: '删除设置',
|
||||||
|
content: '确定要删除所有设置吗?此操作不可恢复。',
|
||||||
|
okText: '确定',
|
||||||
|
cancelText: '取消',
|
||||||
|
okType: 'danger',
|
||||||
|
onOk: async () => {
|
||||||
|
setLoading(true);
|
||||||
|
try {
|
||||||
|
await settingsApi.deleteSettings();
|
||||||
|
message.success('设置已删除');
|
||||||
|
setHasSettings(false);
|
||||||
|
form.resetFields();
|
||||||
|
} catch (error) {
|
||||||
|
message.error('删除设置失败');
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
const apiProviders = [
|
||||||
|
{ value: 'openai', label: 'OpenAI', defaultUrl: 'https://api.openai.com/v1' },
|
||||||
|
{ value: 'azure', label: 'Azure OpenAI', defaultUrl: 'https://YOUR-RESOURCE.openai.azure.com' },
|
||||||
|
{ value: 'anthropic', label: 'Anthropic', defaultUrl: 'https://api.anthropic.com' },
|
||||||
|
{ value: 'custom', label: '自定义', defaultUrl: '' },
|
||||||
|
];
|
||||||
|
|
||||||
|
const handleProviderChange = (value: string) => {
|
||||||
|
const provider = apiProviders.find(p => p.value === value);
|
||||||
|
if (provider && provider.defaultUrl) {
|
||||||
|
form.setFieldValue('api_base_url', provider.defaultUrl);
|
||||||
|
}
|
||||||
|
// 清空模型列表,需要重新获取
|
||||||
|
setModelOptions([]);
|
||||||
|
setModelsFetched(false);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleFetchModels = async (silent: boolean = false) => {
|
||||||
|
const apiKey = form.getFieldValue('api_key');
|
||||||
|
const apiBaseUrl = form.getFieldValue('api_base_url');
|
||||||
|
const provider = form.getFieldValue('api_provider');
|
||||||
|
|
||||||
|
if (!apiKey || !apiBaseUrl) {
|
||||||
|
if (!silent) {
|
||||||
|
message.warning('请先填写 API 密钥和 API 地址');
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
setFetchingModels(true);
|
||||||
|
try {
|
||||||
|
const response = await settingsApi.getAvailableModels({
|
||||||
|
api_key: apiKey,
|
||||||
|
api_base_url: apiBaseUrl,
|
||||||
|
provider: provider || 'openai'
|
||||||
|
});
|
||||||
|
|
||||||
|
setModelOptions(response.models);
|
||||||
|
setModelsFetched(true);
|
||||||
|
if (!silent) {
|
||||||
|
message.success(`成功获取 ${response.count || response.models.length} 个可用模型`);
|
||||||
|
}
|
||||||
|
} catch (error: any) {
|
||||||
|
const errorMsg = error?.response?.data?.detail || '获取模型列表失败';
|
||||||
|
if (!silent) {
|
||||||
|
message.error(errorMsg);
|
||||||
|
}
|
||||||
|
setModelOptions([]);
|
||||||
|
setModelsFetched(true); // 即使失败也标记为已尝试,避免重复请求
|
||||||
|
} finally {
|
||||||
|
setFetchingModels(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleModelSelectFocus = () => {
|
||||||
|
// 如果还没有获取过模型列表,自动获取
|
||||||
|
if (!modelsFetched && !fetchingModels) {
|
||||||
|
handleFetchModels(true); // silent模式,不显示成功消息
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div style={{
|
||||||
|
minHeight: '100vh',
|
||||||
|
background: 'linear-gradient(135deg, #667eea 0%, #764ba2 100%)',
|
||||||
|
padding: window.innerWidth <= 768 ? '20px 16px' : '40px 24px'
|
||||||
|
}}>
|
||||||
|
<div style={{ maxWidth: 800, margin: '0 auto' }}>
|
||||||
|
<Card
|
||||||
|
variant="borderless"
|
||||||
|
style={{
|
||||||
|
background: 'rgba(255, 255, 255, 0.95)',
|
||||||
|
borderRadius: window.innerWidth <= 768 ? 12 : 16,
|
||||||
|
boxShadow: '0 8px 32px rgba(0, 0, 0, 0.1)',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Space direction="vertical" size="large" style={{ width: '100%' }}>
|
||||||
|
{/* 标题栏 */}
|
||||||
|
<div style={{ display: 'flex', alignItems: 'center', justifyContent: 'space-between' }}>
|
||||||
|
<Space>
|
||||||
|
<Button
|
||||||
|
icon={<ArrowLeftOutlined />}
|
||||||
|
onClick={() => navigate('/')}
|
||||||
|
type="text"
|
||||||
|
/>
|
||||||
|
<Title level={window.innerWidth <= 768 ? 3 : 2} style={{ margin: 0 }}>
|
||||||
|
<SettingOutlined style={{ marginRight: 8, color: '#667eea' }} />
|
||||||
|
AI API 设置
|
||||||
|
</Title>
|
||||||
|
</Space>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<Paragraph type="secondary" style={{ marginBottom: 0 }}>
|
||||||
|
配置你的AI API接口参数,这些设置将用于小说生成、角色创建等AI功能。
|
||||||
|
</Paragraph>
|
||||||
|
|
||||||
|
{/* 默认配置提示 */}
|
||||||
|
{isDefaultSettings && (
|
||||||
|
<Alert
|
||||||
|
message="使用 .env 文件中的默认配置"
|
||||||
|
description={
|
||||||
|
<div>
|
||||||
|
<p style={{ margin: '8px 0' }}>
|
||||||
|
当前显示的是从服务器 <code>.env</code> 文件读取的默认配置。
|
||||||
|
</p>
|
||||||
|
<p style={{ margin: '8px 0 0 0' }}>
|
||||||
|
点击"保存设置"后,配置将保存到数据库并同步更新到 <code>.env</code> 文件。
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
type="info"
|
||||||
|
showIcon
|
||||||
|
style={{ marginBottom: 16 }}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* 已保存配置提示 */}
|
||||||
|
{hasSettings && !isDefaultSettings && (
|
||||||
|
<Alert
|
||||||
|
message="使用已保存的个人配置"
|
||||||
|
type="success"
|
||||||
|
showIcon
|
||||||
|
style={{ marginBottom: 16 }}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* 表单 */}
|
||||||
|
<Spin spinning={initialLoading}>
|
||||||
|
<Form
|
||||||
|
form={form}
|
||||||
|
layout="vertical"
|
||||||
|
onFinish={handleSave}
|
||||||
|
autoComplete="off"
|
||||||
|
>
|
||||||
|
<Form.Item
|
||||||
|
label={
|
||||||
|
<Space>
|
||||||
|
<span>API 提供商</span>
|
||||||
|
<Tooltip title="选择你的AI服务提供商">
|
||||||
|
<InfoCircleOutlined style={{ color: '#8c8c8c' }} />
|
||||||
|
</Tooltip>
|
||||||
|
</Space>
|
||||||
|
}
|
||||||
|
name="api_provider"
|
||||||
|
rules={[{ required: true, message: '请选择API提供商' }]}
|
||||||
|
>
|
||||||
|
<Select size="large" onChange={handleProviderChange}>
|
||||||
|
{apiProviders.map(provider => (
|
||||||
|
<Option key={provider.value} value={provider.value}>
|
||||||
|
{provider.label}
|
||||||
|
</Option>
|
||||||
|
))}
|
||||||
|
</Select>
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item
|
||||||
|
label={
|
||||||
|
<Space>
|
||||||
|
<span>API 密钥</span>
|
||||||
|
<Tooltip title="你的API密钥,将加密存储">
|
||||||
|
<InfoCircleOutlined style={{ color: '#8c8c8c' }} />
|
||||||
|
</Tooltip>
|
||||||
|
</Space>
|
||||||
|
}
|
||||||
|
name="api_key"
|
||||||
|
rules={[{ required: true, message: '请输入API密钥' }]}
|
||||||
|
>
|
||||||
|
<Input.Password
|
||||||
|
size="large"
|
||||||
|
placeholder="sk-..."
|
||||||
|
autoComplete="new-password"
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item
|
||||||
|
label={
|
||||||
|
<Space>
|
||||||
|
<span>API 地址</span>
|
||||||
|
<Tooltip title="API的基础URL地址">
|
||||||
|
<InfoCircleOutlined style={{ color: '#8c8c8c' }} />
|
||||||
|
</Tooltip>
|
||||||
|
</Space>
|
||||||
|
}
|
||||||
|
name="api_base_url"
|
||||||
|
rules={[
|
||||||
|
{ required: true, message: '请输入API地址' },
|
||||||
|
{ type: 'url', message: '请输入有效的URL' }
|
||||||
|
]}
|
||||||
|
>
|
||||||
|
<Input
|
||||||
|
size="large"
|
||||||
|
placeholder="https://api.openai.com/v1"
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item
|
||||||
|
label={
|
||||||
|
<Space>
|
||||||
|
<span>模型名称</span>
|
||||||
|
<Tooltip title="AI模型的名称,如 gpt-4, gpt-3.5-turbo">
|
||||||
|
<InfoCircleOutlined style={{ color: '#8c8c8c' }} />
|
||||||
|
</Tooltip>
|
||||||
|
</Space>
|
||||||
|
}
|
||||||
|
name="model_name"
|
||||||
|
rules={[{ required: true, message: '请输入或选择模型名称' }]}
|
||||||
|
>
|
||||||
|
<Select
|
||||||
|
size="large"
|
||||||
|
showSearch
|
||||||
|
placeholder="输入模型名称或点击获取"
|
||||||
|
optionFilterProp="label"
|
||||||
|
loading={fetchingModels}
|
||||||
|
onFocus={handleModelSelectFocus}
|
||||||
|
filterOption={(input, option) =>
|
||||||
|
(option?.label ?? '').toLowerCase().includes(input.toLowerCase()) ||
|
||||||
|
(option?.description ?? '').toLowerCase().includes(input.toLowerCase())
|
||||||
|
}
|
||||||
|
dropdownRender={(menu) => (
|
||||||
|
<>
|
||||||
|
{menu}
|
||||||
|
{fetchingModels && (
|
||||||
|
<div style={{ padding: '8px 12px', color: '#8c8c8c', textAlign: 'center' }}>
|
||||||
|
<Spin size="small" /> 正在获取模型列表...
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
{!fetchingModels && modelOptions.length === 0 && modelsFetched && (
|
||||||
|
<div style={{ padding: '8px 12px', color: '#ff4d4f', textAlign: 'center' }}>
|
||||||
|
未能获取到模型列表,请检查 API 配置
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
{!fetchingModels && modelOptions.length === 0 && !modelsFetched && (
|
||||||
|
<div style={{ padding: '8px 12px', color: '#8c8c8c', textAlign: 'center' }}>
|
||||||
|
点击输入框自动获取模型列表
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
notFoundContent={
|
||||||
|
fetchingModels ? (
|
||||||
|
<div style={{ padding: '8px 12px', textAlign: 'center' }}>
|
||||||
|
<Spin size="small" /> 加载中...
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<div style={{ padding: '8px 12px', color: '#8c8c8c', textAlign: 'center' }}>
|
||||||
|
未找到匹配的模型
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
suffixIcon={
|
||||||
|
<div
|
||||||
|
onClick={(e) => {
|
||||||
|
e.stopPropagation();
|
||||||
|
if (!fetchingModels) {
|
||||||
|
setModelsFetched(false);
|
||||||
|
handleFetchModels(false);
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
style={{
|
||||||
|
cursor: fetchingModels ? 'not-allowed' : 'pointer',
|
||||||
|
display: 'flex',
|
||||||
|
alignItems: 'center',
|
||||||
|
padding: '0 4px',
|
||||||
|
height: '100%',
|
||||||
|
marginRight: -8
|
||||||
|
}}
|
||||||
|
title="重新获取模型列表"
|
||||||
|
>
|
||||||
|
<Button
|
||||||
|
type="text"
|
||||||
|
size="small"
|
||||||
|
icon={<ReloadOutlined />}
|
||||||
|
loading={fetchingModels}
|
||||||
|
style={{ pointerEvents: 'none' }}
|
||||||
|
>
|
||||||
|
刷新
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
options={modelOptions.map(model => ({
|
||||||
|
value: model.value,
|
||||||
|
label: model.label,
|
||||||
|
description: model.description
|
||||||
|
}))}
|
||||||
|
optionRender={(option) => (
|
||||||
|
<div>
|
||||||
|
<div style={{ fontWeight: 500 }}>{option.data.label}</div>
|
||||||
|
{option.data.description && (
|
||||||
|
<div style={{ fontSize: '12px', color: '#8c8c8c' }}>
|
||||||
|
{option.data.description}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item
|
||||||
|
label={
|
||||||
|
<Space>
|
||||||
|
<span>温度参数</span>
|
||||||
|
<Tooltip title="控制输出的随机性,值越高越随机(0.0-2.0)">
|
||||||
|
<InfoCircleOutlined style={{ color: '#8c8c8c' }} />
|
||||||
|
</Tooltip>
|
||||||
|
</Space>
|
||||||
|
}
|
||||||
|
name="temperature"
|
||||||
|
>
|
||||||
|
<Slider
|
||||||
|
min={0}
|
||||||
|
max={2}
|
||||||
|
step={0.1}
|
||||||
|
marks={{
|
||||||
|
0: '0.0',
|
||||||
|
0.7: '0.7',
|
||||||
|
1: '1.0',
|
||||||
|
2: '2.0'
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item
|
||||||
|
label={
|
||||||
|
<Space>
|
||||||
|
<span>最大 Token 数</span>
|
||||||
|
<Tooltip title="单次请求的最大token数量">
|
||||||
|
<InfoCircleOutlined style={{ color: '#8c8c8c' }} />
|
||||||
|
</Tooltip>
|
||||||
|
</Space>
|
||||||
|
}
|
||||||
|
name="max_tokens"
|
||||||
|
rules={[
|
||||||
|
{ required: true, message: '请输入最大token数' },
|
||||||
|
{ type: 'number', min: 1, max: 32000, message: '请输入1-32000之间的数字' }
|
||||||
|
]}
|
||||||
|
>
|
||||||
|
<InputNumber
|
||||||
|
size="large"
|
||||||
|
style={{ width: '100%' }}
|
||||||
|
min={1}
|
||||||
|
max={32000}
|
||||||
|
placeholder="2000"
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
{/* 操作按钮 */}
|
||||||
|
<Form.Item style={{ marginBottom: 0, marginTop: 32 }}>
|
||||||
|
<Space size="middle" style={{ width: '100%', justifyContent: 'space-between' }}>
|
||||||
|
<Space>
|
||||||
|
<Button
|
||||||
|
type="primary"
|
||||||
|
size="large"
|
||||||
|
icon={<SaveOutlined />}
|
||||||
|
htmlType="submit"
|
||||||
|
loading={loading}
|
||||||
|
style={{
|
||||||
|
background: 'linear-gradient(135deg, #667eea 0%, #764ba2 100%)',
|
||||||
|
border: 'none'
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
保存设置
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
size="large"
|
||||||
|
icon={<ReloadOutlined />}
|
||||||
|
onClick={handleReset}
|
||||||
|
>
|
||||||
|
重置
|
||||||
|
</Button>
|
||||||
|
</Space>
|
||||||
|
{hasSettings && (
|
||||||
|
<Button
|
||||||
|
danger
|
||||||
|
size="large"
|
||||||
|
icon={<DeleteOutlined />}
|
||||||
|
onClick={handleDelete}
|
||||||
|
loading={loading}
|
||||||
|
>
|
||||||
|
删除设置
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
</Space>
|
||||||
|
</Form.Item>
|
||||||
|
</Form>
|
||||||
|
</Spin>
|
||||||
|
</Space>
|
||||||
|
</Card>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -23,6 +23,8 @@ import type {
|
|||||||
PolishTextRequest,
|
PolishTextRequest,
|
||||||
GenerateCharactersResponse,
|
GenerateCharactersResponse,
|
||||||
GenerateOutlineResponse,
|
GenerateOutlineResponse,
|
||||||
|
Settings,
|
||||||
|
SettingsUpdate,
|
||||||
} from '../types';
|
} from '../types';
|
||||||
|
|
||||||
const api = axios.create({
|
const api = axios.create({
|
||||||
@@ -124,6 +126,21 @@ export const userApi = {
|
|||||||
getUser: (userId: string) => api.get<unknown, User>(`/users/${userId}`),
|
getUser: (userId: string) => api.get<unknown, User>(`/users/${userId}`),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const settingsApi = {
|
||||||
|
getSettings: () => api.get<unknown, Settings>('/settings'),
|
||||||
|
|
||||||
|
saveSettings: (data: SettingsUpdate) =>
|
||||||
|
api.post<unknown, Settings>('/settings', data),
|
||||||
|
|
||||||
|
updateSettings: (data: SettingsUpdate) =>
|
||||||
|
api.put<unknown, Settings>('/settings', data),
|
||||||
|
|
||||||
|
deleteSettings: () => api.delete<unknown, { message: string; user_id: string }>('/settings'),
|
||||||
|
|
||||||
|
getAvailableModels: (params: { api_key: string; api_base_url: string; provider: string }) =>
|
||||||
|
api.get<unknown, { provider: string; models: Array<{ value: string; label: string; description: string }>; count?: number }>('/settings/models', { params }),
|
||||||
|
};
|
||||||
|
|
||||||
export const projectApi = {
|
export const projectApi = {
|
||||||
getProjects: () => api.get<unknown, Project[]>('/projects'),
|
getProjects: () => api.get<unknown, Project[]>('/projects'),
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,31 @@ export interface User {
|
|||||||
last_login: string;
|
last_login: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 设置类型定义
|
||||||
|
export interface Settings {
|
||||||
|
id: string;
|
||||||
|
user_id: string;
|
||||||
|
api_provider: string;
|
||||||
|
api_key: string;
|
||||||
|
api_base_url: string;
|
||||||
|
model_name: string;
|
||||||
|
temperature: number;
|
||||||
|
max_tokens: number;
|
||||||
|
preferences?: string;
|
||||||
|
created_at: string;
|
||||||
|
updated_at: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface SettingsUpdate {
|
||||||
|
api_provider?: string;
|
||||||
|
api_key?: string;
|
||||||
|
api_base_url?: string;
|
||||||
|
model_name?: string;
|
||||||
|
temperature?: number;
|
||||||
|
max_tokens?: number;
|
||||||
|
preferences?: string;
|
||||||
|
}
|
||||||
|
|
||||||
// LinuxDO 授权 URL 响应
|
// LinuxDO 授权 URL 响应
|
||||||
export interface AuthUrlResponse {
|
export interface AuthUrlResponse {
|
||||||
auth_url: string;
|
auth_url: string;
|
||||||
|
|||||||
Reference in New Issue
Block a user