支持自定义API接口

This commit is contained in:
xiamuceer
2025-10-30 16:53:50 +08:00
parent fe974d1524
commit 3aefdd433d
16 changed files with 1143 additions and 97 deletions
+8 -5
View File
@@ -18,9 +18,10 @@ from app.schemas.chapter import (
ChapterResponse,
ChapterListResponse
)
from app.services.ai_service import ai_service
from app.services.ai_service import AIService
from app.services.prompt_service import prompt_service
from app.logger import get_logger
from app.api.settings import get_user_ai_service
router = APIRouter(prefix="/chapters", tags=["章节管理"])
logger = get_logger(__name__)
@@ -247,7 +248,8 @@ async def check_can_generate(
@router.post("/{chapter_id}/generate", summary="AI创作章节内容")
async def generate_chapter_content(
chapter_id: str,
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
"""
根据大纲、前置章节内容和项目信息AI创作章节完整内容
@@ -372,7 +374,7 @@ async def generate_chapter_content(
logger.info(f"开始AI创作章节 {chapter_id}")
# 调用AI生成
ai_content = await ai_service.generate_text(
ai_content = await user_ai_service.generate_text(
prompt=prompt
)
@@ -410,7 +412,8 @@ async def generate_chapter_content(
@router.post("/{chapter_id}/generate-stream", summary="AI创作章节内容(流式)")
async def generate_chapter_content_stream(
chapter_id: str,
request: Request
request: Request,
user_ai_service: AIService = Depends(get_user_ai_service)
):
"""
根据大纲、前置章节内容和项目信息AI创作章节完整内容(流式返回)
@@ -569,7 +572,7 @@ async def generate_chapter_content_stream(
# 流式生成内容
full_content = ""
async for chunk in ai_service.generate_text_stream(prompt=prompt):
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
full_content += chunk
yield f"data: {json.dumps({'type': 'content', 'content': chunk}, ensure_ascii=False)}\n\n"
await asyncio.sleep(0) # 让出控制权
+5 -3
View File
@@ -15,9 +15,10 @@ from app.schemas.character import (
CharacterListResponse,
CharacterGenerateRequest
)
from app.services.ai_service import ai_service
from app.services.ai_service import AIService
from app.services.prompt_service import prompt_service
from app.logger import get_logger
from app.api.settings import get_user_ai_service
router = APIRouter(prefix="/characters", tags=["角色管理"])
logger = get_logger(__name__)
@@ -134,7 +135,8 @@ async def delete_character(
@router.post("/generate", response_model=CharacterResponse, summary="AI生成角色")
async def generate_character(
request: CharacterGenerateRequest,
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
"""
使用AI生成角色卡
@@ -216,7 +218,7 @@ async def generate_character(
logger.info(f" - Prompt长度:{len(prompt)} 字符")
try:
ai_response = await ai_service.generate_text(
ai_response = await user_ai_service.generate_text(
prompt=prompt,
provider=request.provider,
model=request.model
+12 -8
View File
@@ -19,9 +19,10 @@ from app.schemas.outline import (
OutlineGenerateRequest,
OutlineReorderRequest
)
from app.services.ai_service import ai_service
from app.services.ai_service import AIService
from app.services.prompt_service import prompt_service
from app.logger import get_logger
from app.api.settings import get_user_ai_service
router = APIRouter(prefix="/outlines", tags=["大纲管理"])
logger = get_logger(__name__)
@@ -326,7 +327,8 @@ async def reorder_outlines(
@router.post("/generate", response_model=OutlineListResponse, summary="AI生成/续写大纲")
async def generate_outline(
request: OutlineGenerateRequest,
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
"""
使用AI生成或续写小说大纲 - 智能模式
@@ -363,7 +365,7 @@ async def generate_outline(
# 模式:全新生成
if actual_mode == "new":
return await _generate_new_outline(
request, project, db
request, project, db, user_ai_service
)
# 模式:续写
@@ -375,7 +377,7 @@ async def generate_outline(
)
return await _continue_outline(
request, project, existing_outlines, db
request, project, existing_outlines, db, user_ai_service
)
else:
@@ -394,7 +396,8 @@ async def generate_outline(
async def _generate_new_outline(
request: OutlineGenerateRequest,
project: Project,
db: AsyncSession
db: AsyncSession,
user_ai_service: AIService
) -> OutlineListResponse:
"""全新生成大纲"""
logger.info(f"全新生成大纲 - 项目: {project.id}, keep_existing: {request.keep_existing}")
@@ -427,7 +430,7 @@ async def _generate_new_outline(
)
# 调用AI
ai_response = await ai_service.generate_text(
ai_response = await user_ai_service.generate_text(
prompt=prompt,
provider=request.provider,
model=request.model
@@ -473,7 +476,8 @@ async def _continue_outline(
request: OutlineGenerateRequest,
project: Project,
existing_outlines: List[Outline],
db: AsyncSession
db: AsyncSession,
user_ai_service: AIService
) -> OutlineListResponse:
"""续写大纲"""
logger.info(f"续写大纲 - 项目: {project.id}, 已有: {len(existing_outlines)}")
@@ -536,7 +540,7 @@ async def _continue_outline(
)
# 调用AI
ai_response = await ai_service.generate_text(
ai_response = await user_ai_service.generate_text(
prompt=prompt,
provider=request.provider,
model=request.model
+7 -4
View File
@@ -5,9 +5,10 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.models.generation_history import GenerationHistory
from app.schemas.polish import PolishRequest, PolishResponse
from app.services.ai_service import ai_service
from app.services.ai_service import AIService
from app.services.prompt_service import prompt_service
from app.logger import get_logger
from app.api.settings import get_user_ai_service
router = APIRouter(prefix="/polish", tags=["AI去味"])
logger = get_logger(__name__)
@@ -16,7 +17,8 @@ logger = get_logger(__name__)
@router.post("", response_model=PolishResponse, summary="AI去味")
async def polish_text(
request: PolishRequest,
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
"""
AI去味 - 将AI生成的文本改写得更像人类作家的手笔
@@ -83,7 +85,8 @@ async def polish_batch(
project_id: int = None,
provider: str = None,
model: str = None,
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
"""
批量处理多个文本的AI去味
@@ -98,7 +101,7 @@ async def polish_batch(
prompt = prompt_service.get_denoising_prompt(original_text=text)
polished_text = await ai_service.generate_text(
polished_text = await user_ai_service.generate_text(
prompt=prompt,
provider=provider,
model=model
+299
View File
@@ -0,0 +1,299 @@
"""
设置管理 API
"""
from fastapi import APIRouter, HTTPException, Request, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from typing import Dict, Any, List
from pathlib import Path
import httpx
from app.database import get_db
from app.models.settings import Settings
from app.schemas.settings import SettingsCreate, SettingsUpdate, SettingsResponse
from app.user_manager import User
from app.logger import get_logger
from app.config import settings as app_settings, PROJECT_ROOT
from app.services.ai_service import AIService, create_user_ai_service
logger = get_logger(__name__)
router = APIRouter(prefix="/settings", tags=["设置管理"])
def read_env_defaults() -> Dict[str, Any]:
"""从.env文件读取默认配置(仅读取,不修改)"""
return {
"api_provider": app_settings.default_ai_provider,
"api_key": app_settings.openai_api_key or app_settings.anthropic_api_key or "",
"api_base_url": app_settings.openai_base_url or app_settings.anthropic_base_url or "",
"model_name": app_settings.default_model,
"temperature": app_settings.default_temperature,
"max_tokens": app_settings.default_max_tokens,
}
def require_login(request: Request):
"""依赖:要求用户已登录"""
if not hasattr(request.state, "user") or not request.state.user:
raise HTTPException(status_code=401, detail="需要登录")
return request.state.user
async def get_user_ai_service(
user: User = Depends(require_login),
db: AsyncSession = Depends(get_db)
) -> AIService:
"""
依赖:获取当前用户的AI服务实例
从数据库读取用户设置并创建对应的AI服务
"""
result = await db.execute(
select(Settings).where(Settings.user_id == user.user_id)
)
settings = result.scalar_one_or_none()
if not settings:
# 如果用户没有设置,从.env读取并保存
env_defaults = read_env_defaults()
settings = Settings(
user_id=user.user_id,
**env_defaults
)
db.add(settings)
await db.commit()
await db.refresh(settings)
logger.info(f"用户 {user.user_id} 首次使用AI服务,已从.env同步设置到数据库")
# 使用用户设置创建AI服务实例
return create_user_ai_service(
api_provider=settings.api_provider,
api_key=settings.api_key,
api_base_url=settings.api_base_url or "",
model_name=settings.model_name,
temperature=settings.temperature,
max_tokens=settings.max_tokens
)
@router.get("", response_model=SettingsResponse)
async def get_settings(
user: User = Depends(require_login),
db: AsyncSession = Depends(get_db)
):
"""
获取当前用户的设置
如果用户没有保存过设置,自动从.env创建并保存到数据库
"""
result = await db.execute(
select(Settings).where(Settings.user_id == user.user_id)
)
settings = result.scalar_one_or_none()
if not settings:
# 如果用户没有保存过设置,从.env读取默认配置并保存到数据库
env_defaults = read_env_defaults()
logger.info(f"用户 {user.user_id} 首次获取设置,自动从.env同步到数据库")
# 创建新设置并保存到数据库
settings = Settings(
user_id=user.user_id,
**env_defaults
)
db.add(settings)
await db.commit()
await db.refresh(settings)
logger.info(f"用户 {user.user_id} 的设置已从.env同步到数据库")
logger.info(f"用户 {user.user_id} 获取已保存的设置")
return settings
@router.post("", response_model=SettingsResponse)
async def save_settings(
data: SettingsCreate,
user: User = Depends(require_login),
db: AsyncSession = Depends(get_db)
):
"""
创建或更新当前用户的设置(Upsert)
如果设置已存在则更新,否则创建新设置
仅保存到数据库
"""
# 查找现有设置
result = await db.execute(
select(Settings).where(Settings.user_id == user.user_id)
)
settings = result.scalar_one_or_none()
# 准备数据
settings_dict = data.model_dump(exclude_unset=True)
if settings:
# 更新现有设置
for key, value in settings_dict.items():
setattr(settings, key, value)
await db.commit()
await db.refresh(settings)
logger.info(f"用户 {user.user_id} 更新设置")
else:
# 创建新设置
settings = Settings(
user_id=user.user_id,
**settings_dict
)
db.add(settings)
await db.commit()
await db.refresh(settings)
logger.info(f"用户 {user.user_id} 创建设置")
return settings
@router.put("", response_model=SettingsResponse)
async def update_settings(
data: SettingsUpdate,
user: User = Depends(require_login),
db: AsyncSession = Depends(get_db)
):
"""
更新当前用户的设置
仅保存到数据库
"""
result = await db.execute(
select(Settings).where(Settings.user_id == user.user_id)
)
settings = result.scalar_one_or_none()
if not settings:
raise HTTPException(status_code=404, detail="设置不存在,请先创建设置")
# 更新设置
update_data = data.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(settings, key, value)
await db.commit()
await db.refresh(settings)
logger.info(f"用户 {user.user_id} 更新设置")
return settings
@router.delete("")
async def delete_settings(
user: User = Depends(require_login),
db: AsyncSession = Depends(get_db)
):
"""
删除当前用户的设置
"""
result = await db.execute(
select(Settings).where(Settings.user_id == user.user_id)
)
settings = result.scalar_one_or_none()
if not settings:
raise HTTPException(status_code=404, detail="设置不存在")
await db.delete(settings)
await db.commit()
logger.info(f"用户 {user.user_id} 删除设置")
return {"message": "设置已删除", "user_id": user.user_id}
@router.get("/models")
async def get_available_models(
api_key: str,
api_base_url: str,
provider: str = "openai"
):
"""
从配置的 API 获取可用的模型列表
Args:
api_key: API 密钥
api_base_url: API 基础 URL
provider: API 提供商 (openai, anthropic, azure, custom)
Returns:
模型列表
"""
try:
async with httpx.AsyncClient(timeout=10.0) as client:
if provider == "openai" or provider == "azure" or provider == "custom":
# OpenAI 兼容接口获取模型列表
url = f"{api_base_url.rstrip('/')}/models"
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
logger.info(f"正在从 {url} 获取模型列表")
response = await client.get(url, headers=headers)
response.raise_for_status()
data = response.json()
models = []
if "data" in data and isinstance(data["data"], list):
for model in data["data"]:
model_id = model.get("id", "")
# 过滤出常用的文本生成模型
if any(keyword in model_id.lower() for keyword in [
"gpt", "gemini", "claude", "llama", "mistral", "qwen", "deepseek"
]):
models.append({
"value": model_id,
"label": model_id,
"description": model.get("description", "") or f"Created: {model.get('created', 'N/A')}"
})
if not models:
raise HTTPException(
status_code=404,
detail="未能从 API 获取到可用的模型列表"
)
logger.info(f"成功获取 {len(models)} 个模型")
return {
"provider": provider,
"models": models,
"count": len(models)
}
elif provider == "anthropic":
# Anthropic 没有公开的模型列表API
raise HTTPException(
status_code=400,
detail="Anthropic 不支持自动获取模型列表,请手动输入模型名称"
)
else:
raise HTTPException(
status_code=400,
detail=f"不支持的提供商: {provider}"
)
except httpx.HTTPStatusError as e:
logger.error(f"获取模型列表失败 (HTTP {e.response.status_code}): {e.response.text}")
raise HTTPException(
status_code=400,
detail=f"无法从 API 获取模型列表 (HTTP {e.response.status_code})"
)
except httpx.RequestError as e:
logger.error(f"请求模型列表失败: {str(e)}")
raise HTTPException(
status_code=400,
detail=f"无法连接到 API: {str(e)}"
)
except HTTPException:
raise
except Exception as e:
logger.error(f"获取模型列表时发生错误: {str(e)}")
raise HTTPException(
status_code=500,
detail=f"获取模型列表失败: {str(e)}"
)
+56 -38
View File
@@ -4,6 +4,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from typing import Dict, Any, AsyncGenerator
import json
import re
from app.database import get_db
from app.models.project import Project
@@ -11,10 +12,11 @@ from app.models.character import Character
from app.models.outline import Outline
from app.models.chapter import Chapter
from app.models.relationship import CharacterRelationship, Organization, OrganizationMember, RelationshipType
from app.services.ai_service import ai_service
from app.services.ai_service import AIService
from app.services.prompt_service import prompt_service
from app.logger import get_logger
from app.utils.sse_response import SSEResponse, create_sse_response
from app.api.settings import get_user_ai_service
router = APIRouter(prefix="/wizard-stream", tags=["项目创建向导(流式)"])
logger = get_logger(__name__)
@@ -22,7 +24,8 @@ logger = get_logger(__name__)
async def world_building_generator(
data: Dict[str, Any],
db: AsyncSession
db: AsyncSession,
user_ai_service: AIService
) -> AsyncGenerator[str, None]:
"""世界构建流式生成器"""
# 标记数据库会话是否已提交
@@ -61,7 +64,7 @@ async def world_building_generator(
accumulated_text = ""
chunk_count = 0
async for chunk in ai_service.generate_text_stream(
async for chunk in user_ai_service.generate_text_stream(
prompt=prompt,
provider=provider,
model=model
@@ -87,24 +90,26 @@ async def world_building_generator(
world_data = {}
try:
cleaned_text = accumulated_text.strip()
# 移除markdown代码块标记
if cleaned_text.startswith('```json'):
cleaned_text = cleaned_text[7:]
if cleaned_text.startswith('```'):
cleaned_text = cleaned_text[3:]
cleaned_text = cleaned_text[7:].lstrip('\n\r')
elif cleaned_text.startswith('```'):
cleaned_text = cleaned_text[3:].lstrip('\n\r')
if cleaned_text.endswith('```'):
cleaned_text = cleaned_text[:-3]
cleaned_text = cleaned_text[:-3].rstrip('\n\r')
cleaned_text = cleaned_text.strip()
world_data = json.loads(cleaned_text)
except json.JSONDecodeError as e:
logger.error(f"AI返回非JSON格式: {e}")
logger.error(f"世界构建JSON解析失败: {e}")
world_data = {
"time_period": accumulated_text[:300] if len(accumulated_text) > 300 else accumulated_text,
"time_period": "AI返回格式错误,请重试",
"location": "AI返回格式错误,请重试",
"atmosphere": "AI返回格式错误,请重试",
"rules": "AI返回格式错误,请重试"
}
# 保存到数据库
yield await SSEResponse.send_progress("保存到数据库...", 90)
@@ -160,18 +165,20 @@ async def world_building_generator(
@router.post("/world-building", summary="流式生成世界构建")
async def generate_world_building_stream(
data: Dict[str, Any],
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
"""
使用SSE流式生成世界构建,避免超时
前端使用EventSource接收实时进度和结果
"""
return create_sse_response(world_building_generator(data, db))
return create_sse_response(world_building_generator(data, db, user_ai_service))
async def characters_generator(
data: Dict[str, Any],
db: AsyncSession
db: AsyncSession,
user_ai_service: AIService
) -> AsyncGenerator[str, None]:
"""角色批量生成流式生成器 - 优化版:分批+重试"""
db_committed = False
@@ -270,7 +277,7 @@ async def characters_generator(
# 流式生成
accumulated_text = ""
async for chunk in ai_service.generate_text_stream(
async for chunk in user_ai_service.generate_text_stream(
prompt=prompt,
provider=provider,
model=model
@@ -280,12 +287,13 @@ async def characters_generator(
# 解析批次结果
cleaned_text = accumulated_text.strip()
# 移除markdown代码块标记
if cleaned_text.startswith('```json'):
cleaned_text = cleaned_text[7:]
if cleaned_text.startswith('```'):
cleaned_text = cleaned_text[3:]
cleaned_text = cleaned_text[7:].lstrip('\n\r')
elif cleaned_text.startswith('```'):
cleaned_text = cleaned_text[3:].lstrip('\n\r')
if cleaned_text.endswith('```'):
cleaned_text = cleaned_text[:-3]
cleaned_text = cleaned_text[:-3].rstrip('\n\r')
cleaned_text = cleaned_text.strip()
characters_data = json.loads(cleaned_text)
@@ -684,17 +692,19 @@ async def characters_generator(
@router.post("/characters", summary="流式批量生成角色")
async def generate_characters_stream(
data: Dict[str, Any],
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
"""
使用SSE流式批量生成角色,避免超时
"""
return create_sse_response(characters_generator(data, db))
return create_sse_response(characters_generator(data, db, user_ai_service))
async def outline_generator(
data: Dict[str, Any],
db: AsyncSession
db: AsyncSession,
user_ai_service: AIService
) -> AsyncGenerator[str, None]:
"""大纲生成流式生成器 - 向导固定生成前8章作为开局"""
db_committed = False
@@ -778,6 +788,7 @@ async def outline_generator(
batch_requirements += "2. 建立主线冲突和故事钩子\n"
batch_requirements += "3. 展开初期情节,为后续发展埋下伏笔\n"
batch_requirements += "4. 不要试图完结故事,这只是开始部分\n"
batch_requirements += "5. 不要在JSON字符串值中使用中文引号(""''),请使用【】或《》标记\n"
batch_prompt = prompt_service.get_complete_outline_prompt(
title=project.title,
@@ -796,7 +807,7 @@ async def outline_generator(
# 流式生成
accumulated_text = ""
async for chunk in ai_service.generate_text_stream(
async for chunk in user_ai_service.generate_text_stream(
prompt=batch_prompt,
provider=provider,
model=model
@@ -806,12 +817,14 @@ async def outline_generator(
# 解析结果
cleaned_text = accumulated_text.strip()
# 移除markdown代码块标记
if cleaned_text.startswith('```json'):
cleaned_text = cleaned_text[7:]
if cleaned_text.startswith('```'):
cleaned_text = cleaned_text[3:]
cleaned_text = cleaned_text[7:].lstrip('\n\r')
elif cleaned_text.startswith('```'):
cleaned_text = cleaned_text[3:].lstrip('\n\r')
if cleaned_text.endswith('```'):
cleaned_text = cleaned_text[:-3]
cleaned_text = cleaned_text[:-3].rstrip('\n\r')
cleaned_text = cleaned_text.strip()
batch_outline_data = json.loads(cleaned_text)
@@ -839,7 +852,7 @@ async def outline_generator(
logger.info(f"批次{batch_idx+1}成功生成{len(batch_outline_data)}章大纲")
except json.JSONDecodeError as e:
logger.error(f"批次{batch_idx+1}解析失败(尝试{retry_count+1}/{MAX_RETRIES}): {e}")
logger.error(f"大纲生成批次{batch_idx+1} JSON解析失败(尝试{retry_count+1}/{MAX_RETRIES}): {e}")
retry_count += 1
if retry_count < MAX_RETRIES:
yield await SSEResponse.send_progress(
@@ -945,12 +958,13 @@ async def outline_generator(
@router.post("/outline", summary="流式生成完整大纲")
async def generate_outline_stream(
data: Dict[str, Any],
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
"""
使用SSE流式生成完整大纲,避免超时
"""
return create_sse_response(outline_generator(data, db))
return create_sse_response(outline_generator(data, db, user_ai_service))
async def update_world_building_generator(
@@ -1037,7 +1051,8 @@ async def update_world_building_stream(
async def regenerate_world_building_generator(
project_id: str,
data: Dict[str, Any],
db: AsyncSession
db: AsyncSession,
user_ai_service: AIService
) -> AsyncGenerator[str, None]:
"""重新生成世界观流式生成器"""
db_committed = False
@@ -1070,7 +1085,7 @@ async def regenerate_world_building_generator(
accumulated_text = ""
chunk_count = 0
async for chunk in ai_service.generate_text_stream(
async for chunk in user_ai_service.generate_text_stream(
prompt=prompt,
provider=provider,
model=model
@@ -1096,19 +1111,21 @@ async def regenerate_world_building_generator(
world_data = {}
try:
cleaned_text = accumulated_text.strip()
# 移除markdown代码块标记
if cleaned_text.startswith('```json'):
cleaned_text = cleaned_text[7:]
if cleaned_text.startswith('```'):
cleaned_text = cleaned_text[3:]
cleaned_text = cleaned_text[7:].lstrip('\n\r')
elif cleaned_text.startswith('```'):
cleaned_text = cleaned_text[3:].lstrip('\n\r')
if cleaned_text.endswith('```'):
cleaned_text = cleaned_text[:-3]
cleaned_text = cleaned_text[:-3].rstrip('\n\r')
cleaned_text = cleaned_text.strip()
world_data = json.loads(cleaned_text)
except json.JSONDecodeError as e:
logger.error(f"AI返回非JSON格式: {e}")
logger.info(world_data)
world_data = {
"time_period": accumulated_text[:300] if len(accumulated_text) > 300 else accumulated_text,
"time_period": "AI返回格式错误,请重试",
"location": "AI返回格式错误,请重试",
"atmosphere": "AI返回格式错误,请重试",
"rules": "AI返回格式错误,请重试"
@@ -1155,7 +1172,8 @@ async def regenerate_world_building_generator(
async def regenerate_world_building_stream(
project_id: str,
data: Dict[str, Any],
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
"""
使用SSE流式重新生成项目的世界观
@@ -1165,7 +1183,7 @@ async def regenerate_world_building_stream(
"model": "模型名称(可选)"
}
"""
return create_sse_response(regenerate_world_building_generator(project_id, data, db))
return create_sse_response(regenerate_world_building_generator(project_id, data, db, user_ai_service))
async def cleanup_wizard_data_generator(
+2 -1
View File
@@ -114,11 +114,12 @@ async def db_session_stats():
from app.api import (
projects, outlines, characters, chapters,
wizard_stream, relationships, organizations,
auth, users
auth, users, settings
)
app.include_router(auth.router, prefix="/api")
app.include_router(users.router, prefix="/api")
app.include_router(settings.router, prefix="/api")
app.include_router(projects.router, prefix="/api")
app.include_router(wizard_stream.router, prefix="/api")
+7 -2
View File
@@ -1,5 +1,5 @@
"""设置数据模型"""
from sqlalchemy import Column, String, Text, Float, Integer, DateTime
from sqlalchemy import Column, String, Text, Float, Integer, DateTime, Index
from sqlalchemy.sql import func
from app.database import Base
import uuid
@@ -10,6 +10,7 @@ class Settings(Base):
__tablename__ = "settings"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
user_id = Column(String(50), nullable=False, unique=True, index=True, comment="用户ID")
api_provider = Column(String(50), default="openai", comment="API提供商")
api_key = Column(String(500), comment="API密钥")
api_base_url = Column(String(500), comment="自定义API地址")
@@ -20,5 +21,9 @@ class Settings(Base):
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
__table_args__ = (
Index('idx_user_id', 'user_id'),
)
def __repr__(self):
return f"<Settings(id={self.id}, api_provider={self.api_provider})>"
return f"<Settings(id={self.id}, user_id={self.user_id}, api_provider={self.api_provider})>"
+37
View File
@@ -0,0 +1,37 @@
"""设置相关的Pydantic模型"""
from pydantic import BaseModel, Field, ConfigDict
from typing import Optional
from datetime import datetime
class SettingsBase(BaseModel):
"""设置基础模型"""
model_config = ConfigDict(protected_namespaces=())
api_provider: Optional[str] = Field(default="openai", description="API提供商")
api_key: Optional[str] = Field(default=None, description="API密钥")
api_base_url: Optional[str] = Field(default=None, description="自定义API地址")
model_name: Optional[str] = Field(default="gpt-4", description="模型名称")
temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0, description="温度参数")
max_tokens: Optional[int] = Field(default=2000, ge=1, le=32000, description="最大token数")
preferences: Optional[str] = Field(default=None, description="其他偏好设置(JSON)")
class SettingsCreate(SettingsBase):
"""创建设置请求模型"""
pass
class SettingsUpdate(SettingsBase):
"""更新设置请求模型"""
pass
class SettingsResponse(SettingsBase):
"""设置响应模型"""
model_config = ConfigDict(from_attributes=True, protected_namespaces=())
id: str
user_id: str
created_at: datetime
updated_at: datetime
+84 -22
View File
@@ -2,7 +2,7 @@
from typing import Optional, AsyncGenerator, List, Dict, Any
from openai import AsyncOpenAI
from anthropic import AsyncAnthropic
from app.config import settings
from app.config import settings as app_settings
from app.logger import get_logger
import httpx
@@ -10,12 +10,37 @@ logger = get_logger(__name__)
class AIService:
"""AI服务统一接口"""
"""AI服务统一接口 - 支持从用户设置或全局配置初始化"""
def __init__(self):
"""初始化AI客户端(优化并发性能)"""
def __init__(
self,
api_provider: Optional[str] = None,
api_key: Optional[str] = None,
api_base_url: Optional[str] = None,
default_model: Optional[str] = None,
default_temperature: Optional[float] = None,
default_max_tokens: Optional[int] = None
):
"""
初始化AI客户端(优化并发性能)
Args:
api_provider: API提供商 (openai/anthropic),为None时使用全局配置
api_key: API密钥,为None时使用全局配置
api_base_url: API基础URL,为None时使用全局配置
default_model: 默认模型,为None时使用全局配置
default_temperature: 默认温度,为None时使用全局配置
default_max_tokens: 默认最大tokens,为None时使用全局配置
"""
# 保存用户设置或使用全局配置
self.api_provider = api_provider or app_settings.default_ai_provider
self.default_model = default_model or app_settings.default_model
self.default_temperature = default_temperature or app_settings.default_temperature
self.default_max_tokens = default_max_tokens or app_settings.default_max_tokens
# 初始化OpenAI客户端
if settings.openai_api_key:
openai_key = api_key if api_provider == "openai" else app_settings.openai_api_key
if openai_key:
# 创建自定义的httpx客户端来避免proxies参数问题
try:
# 配置连接池限制,支持高并发
@@ -43,12 +68,14 @@ class AIService:
)
client_kwargs = {
"api_key": settings.openai_api_key,
"api_key": openai_key,
"http_client": http_client
}
if settings.openai_base_url:
client_kwargs["base_url"] = settings.openai_base_url
# 优先使用用户提供的base_url,否则使用全局配置
base_url = api_base_url if api_provider == "openai" else app_settings.openai_base_url
if base_url:
client_kwargs["base_url"] = base_url
self.openai_client = AsyncOpenAI(**client_kwargs)
logger.info("✅ OpenAI客户端初始化成功")
@@ -62,7 +89,8 @@ class AIService:
logger.warning("OpenAI API key未配置")
# 初始化Anthropic客户端
if settings.anthropic_api_key:
anthropic_key = api_key if api_provider == "anthropic" else app_settings.anthropic_api_key
if anthropic_key:
try:
# 为Anthropic设置相同的超时和连接池配置
limits = httpx.Limits(
@@ -82,12 +110,14 @@ class AIService:
)
client_kwargs = {
"api_key": settings.anthropic_api_key,
"api_key": anthropic_key,
"http_client": http_client
}
if settings.anthropic_base_url:
client_kwargs["base_url"] = settings.anthropic_base_url
# 优先使用用户提供的base_url,否则使用全局配置
base_url = api_base_url if api_provider == "anthropic" else app_settings.anthropic_base_url
if base_url:
client_kwargs["base_url"] = base_url
self.anthropic_client = AsyncAnthropic(**client_kwargs)
logger.info("✅ Anthropic客户端初始化成功")
@@ -123,10 +153,10 @@ class AIService:
Returns:
生成的文本
"""
provider = provider or settings.default_ai_provider
model = model or settings.default_model
temperature = temperature or settings.default_temperature
max_tokens = max_tokens or settings.default_max_tokens
provider = provider or self.api_provider
model = model or self.default_model
temperature = temperature or self.default_temperature
max_tokens = max_tokens or self.default_max_tokens
if provider == "openai":
return await self._generate_openai(
@@ -162,10 +192,10 @@ class AIService:
Yields:
生成的文本片段
"""
provider = provider or settings.default_ai_provider
model = model or settings.default_model
temperature = temperature or settings.default_temperature
max_tokens = max_tokens or settings.default_max_tokens
provider = provider or self.api_provider
model = model or self.default_model
temperature = temperature or self.default_temperature
max_tokens = max_tokens or self.default_max_tokens
if provider == "openai":
async for chunk in self._generate_openai_stream(
@@ -359,5 +389,37 @@ class AIService:
raise
# 创建全局AI服务实例
ai_service = AIService()
# 创建全局AI服务实例(使用环境变量配置,用于向后兼容)
ai_service = AIService()
def create_user_ai_service(
api_provider: str,
api_key: str,
api_base_url: str,
model_name: str,
temperature: float,
max_tokens: int
) -> AIService:
"""
根据用户设置创建AI服务实例
Args:
api_provider: API提供商
api_key: API密钥
api_base_url: API基础URL
model_name: 模型名称
temperature: 温度参数
max_tokens: 最大tokens
Returns:
AIService实例
"""
return AIService(
api_provider=api_provider,
api_key=api_key,
api_base_url=api_base_url,
default_model=model_name,
default_temperature=temperature,
default_max_tokens=max_tokens
)
+48 -12
View File
@@ -26,7 +26,14 @@ class PromptService:
- 为故事发展提供支撑
- 具有独特性和吸引力
**重要:你必须只返回纯JSON格式,不要包含任何markdown标记、代码块标记或其他说明文字。**
**重要格式要求**
1. 只返回纯JSON格式,不要包含任何markdown标记、代码块标记或其他说明文字
2. 不要在JSON字符串值中使用中文引号(""''),请使用英文引号或直接省略引号
3. 专有名词和强调内容可以使用【】或《》标记,不要用引号
**正确示例**
- ✅ "距离【大灾变】爆发""距离大灾变爆发"
- ❌ "距离"大灾变"爆发" (会导致JSON解析失败)
请严格按照以下JSON格式返回(每个字段为200-300字的文本描述):
{{
@@ -36,7 +43,10 @@ class PromptService:
"rules": "世界规则的详细描述,包括运行法则、特殊设定、社会规则、权力结构"
}}
再次强调:只返回纯JSON对象,不要有```json```这样的标记,不要有任何额外的文字说明。"""
再次强调:
1. 只返回纯JSON对象,不要有```json```这样的标记
2. 文本中不要使用中文引号(""),使用【】或《》代替
3. 不要有任何额外的文字说明"""
# 批量角色生成提示词
CHARACTERS_BATCH_GENERATION = """你是一位专业的角色设定师。请根据以下世界观和要求,生成{count}个立体丰满的角色和组织:
@@ -67,7 +77,10 @@ class PromptService:
- 组织要有存在的合理性
- 所有实体要为故事服务
**重要:你必须只返回纯JSON数组格式,不要包含任何markdown标记、代码块标记或其他说明文字。**
**重要格式要求**
1. 只返回纯JSON数组格式,不要包含任何markdown标记、代码块标记或其他说明文字
2. 不要在JSON字符串值中使用中文引号(""''),请使用英文引号或【】《》标记
3. 专有名词和强调内容使用【】或《》,不要用引号
请严格按照以下JSON数组格式返回(每个角色为数组中的一个对象):
[
@@ -134,7 +147,8 @@ class PromptService:
再次强调:
1. 只返回纯JSON数组,不要有```json```这样的标记
2. 数组中必须精确包含{count}个对象
3. 不要引用任何本批次中不存在的角色或组织名称"""
3. 不要引用任何本批次中不存在的角色或组织名称
4. 文本描述中不要使用中文引号(""),改用【】或《》"""
# 完整大纲生成提示词
COMPLETE_OUTLINE_GENERATION = """你是一位经验丰富的小说作家和编剧。请根据以下信息生成完整的{chapter_count}章小说大纲:
@@ -166,7 +180,10 @@ class PromptService:
- 节奏把控:有张有弛
- 视角统一:采用{narrative_perspective}视角叙事
**重要:你必须只返回纯JSON数组格式,不要包含任何markdown标记、代码块标记或其他说明文字。**
**重要格式要求**
1. 只返回纯JSON数组格式,不要包含任何markdown标记、代码块标记或其他说明文字
2. 不要在JSON字符串值中使用中文引号(""''),请使用【】或《》标记
3. 专有名词、书名、事件名使用【】或《》
请严格按照以下JSON数组格式返回(共{chapter_count}个章节对象):
[
@@ -192,7 +209,10 @@ class PromptService:
}}
]
再次强调:只返回纯JSON数组,不要有```json```这样的标记,不要有任何额外的文字说明。数组中要包含{chapter_count}个章节对象。"""
再次强调:
1. 只返回纯JSON数组,不要有```json```这样的标记
2. 数组中要包含{chapter_count}个章节对象
3. 文本中不要使用中文引号(""),改用【】或《》"""
# 大纲续写提示词
OUTLINE_CONTINUE_GENERATION = """你是一位经验丰富的小说作家和编剧。请基于以下信息续写小说大纲:
@@ -232,7 +252,10 @@ class PromptService:
- 保持与已有章节相同的风格和详细程度
- 推进角色成长和情节发展
**重要:你必须只返回纯JSON数组格式,不要包含任何markdown标记、代码块标记或其他说明文字。**
**重要格式要求**
1. 只返回纯JSON数组格式,不要包含任何markdown标记、代码块标记或其他说明文字
2. 不要在JSON字符串值中使用中文引号(""''),请使用【】或《》
3. 文本描述中的专有名词使用【】标记
请严格按照以下JSON数组格式返回(共{chapter_count}个章节对象):
[
@@ -262,7 +285,8 @@ class PromptService:
1. 只返回纯JSON数组,不要有```json```这样的标记
2. 数组中要包含{chapter_count}个章节对象
3. 每个summary必须是100-200字的详细描述
4. 确保字段结构与已有章节完全一致"""
4. 确保字段结构与已有章节完全一致
5. 文本中不要使用中文引号(""),改用【】或《》"""
# AI去味提示词(核心特色功能)
AI_DENOISING = """你是一位追求自然写作风格的编辑。你的任务是将AI生成的文本改写得更像人类作家的手笔。
@@ -431,7 +455,10 @@ class PromptService:
4. 情节的递进和冲突升级
5. 角色的成长弧线
**重要:你必须只返回纯JSON格式,不要包含任何markdown标记、代码块标记或其他说明文字。**
**重要格式要求**
1. 只返回纯JSON格式,不要包含任何markdown标记、代码块标记或其他说明文字
2. 不要在JSON字符串值中使用中文引号(""''),改用【】或《》
3. 专有名词和强调内容使用【】标记
请严格按照以下JSON格式返回:
{{
@@ -444,7 +471,10 @@ class PromptService:
]
}}
再次强调:只返回纯JSON对象,不要有```json```这样的标记,不要有任何额外的文字说明。"""
再次强调:
1. 只返回纯JSON对象,不要有```json```这样的标记
2. 文本中不要使用中文引号(""),改用【】或《》
3. 不要有任何额外的文字说明"""
# 单个角色生成提示词
SINGLE_CHARACTER_GENERATION = """你是一位专业的角色设定师。请根据以下信息创建一个立体饱满的小说角色。
@@ -487,7 +517,10 @@ class PromptService:
- 特殊技能或知识
- 符合世界观设定
**你必须只返回纯JSON格式,不要包含任何markdown标记、代码块标记或其他说明文字。**
**重要格式要求:**
1. 只返回纯JSON格式,不要包含任何markdown标记、代码块标记或其他说明文字
2. 不要在JSON字符串值中使用中文引号(""''),改用【】或《》
3. 文本描述中的专有名词使用【】标记
请严格按照以下JSON格式返回:
{{
@@ -543,7 +576,10 @@ class PromptService:
- 配角要有独特性,不能是工具人
- 所有设定要为故事服务
再次强调:只返回纯JSON对象,不要有```json```这样的标记,不要有任何额外的文字说明。"""
再次强调:
1. 只返回纯JSON对象,不要有```json```这样的标记
2. 文本中不要使用中文引号(""),改用【】或《》
3. 不要有任何额外的文字说明"""
@staticmethod
def format_prompt(template: str, **kwargs) -> str:
+2
View File
@@ -10,6 +10,7 @@ import Characters from './pages/Characters';
import Relationships from './pages/Relationships';
import Organizations from './pages/Organizations';
import Chapters from './pages/Chapters';
import Settings from './pages/Settings';
// import Polish from './pages/Polish';
import Login from './pages/Login';
import AuthCallback from './pages/AuthCallback';
@@ -31,6 +32,7 @@ function App() {
<Route path="/" element={<ProtectedRoute><ProjectList /></ProtectedRoute>} />
<Route path="/wizard" element={<ProtectedRoute><ProjectWizardNew /></ProtectedRoute>} />
<Route path="/settings" element={<ProtectedRoute><Settings /></ProtectedRoute>} />
<Route path="/project/:projectId" element={<ProtectedRoute><ProjectDetail /></ProtectedRoute>}>
<Route index element={<Navigate to="world-setting" replace />} />
<Route path="world-setting" element={<WorldSetting />} />
+27 -2
View File
@@ -1,7 +1,7 @@
import { useEffect } from 'react';
import { useNavigate } from 'react-router-dom';
import { Card, Button, Empty, Modal, message, Spin, Row, Col, Statistic, Space, Tag, Progress, Typography, Tooltip, Badge } from 'antd';
import { EditOutlined, DeleteOutlined, BookOutlined, RocketOutlined, CalendarOutlined, FileTextOutlined, TrophyOutlined, FireOutlined } from '@ant-design/icons';
import { EditOutlined, DeleteOutlined, BookOutlined, RocketOutlined, CalendarOutlined, FileTextOutlined, TrophyOutlined, FireOutlined, SettingOutlined } from '@ant-design/icons';
import { useStore } from '../store';
import { useProjectSync } from '../store/hooks';
import type { ReactNode } from 'react';
@@ -161,11 +161,36 @@ export default function ProjectList() {
style={{
borderRadius: 8,
background: 'linear-gradient(135deg, #667eea 0%, #764ba2 100%)',
border: 'none'
border: 'none',
boxShadow: '0 2px 8px rgba(102, 126, 234, 0.4)'
}}
>
</Button>
<Button
type="default"
size={window.innerWidth <= 768 ? 'middle' : 'large'}
icon={<SettingOutlined />}
onClick={() => navigate('/settings')}
style={{
borderRadius: 8,
borderColor: '#d9d9d9',
boxShadow: '0 2px 8px rgba(0, 0, 0, 0.08)',
transition: 'all 0.3s ease'
}}
onMouseEnter={(e) => {
e.currentTarget.style.borderColor = '#667eea';
e.currentTarget.style.color = '#667eea';
e.currentTarget.style.boxShadow = '0 2px 12px rgba(102, 126, 234, 0.3)';
}}
onMouseLeave={(e) => {
e.currentTarget.style.borderColor = '#d9d9d9';
e.currentTarget.style.color = 'rgba(0, 0, 0, 0.88)';
e.currentTarget.style.boxShadow = '0 2px 8px rgba(0, 0, 0, 0.08)';
}}
>
API设置
</Button>
<UserMenu />
</Col>
</Row>
+507
View File
@@ -0,0 +1,507 @@
import { useState, useEffect } from 'react';
import { useNavigate } from 'react-router-dom';
import { Card, Form, Input, Button, Select, Slider, InputNumber, message, Space, Typography, Spin, Modal, Tooltip, Alert } from 'antd';
import { SettingOutlined, SaveOutlined, DeleteOutlined, ReloadOutlined, ArrowLeftOutlined, InfoCircleOutlined} from '@ant-design/icons';
import { settingsApi } from '../services/api';
import type { SettingsUpdate } from '../types';
const { Title, Paragraph } = Typography;
const { Option } = Select;
export default function SettingsPage() {
const navigate = useNavigate();
const [form] = Form.useForm();
const [loading, setLoading] = useState(false);
const [initialLoading, setInitialLoading] = useState(true);
const [hasSettings, setHasSettings] = useState(false);
const [isDefaultSettings, setIsDefaultSettings] = useState(false);
const [modelOptions, setModelOptions] = useState<Array<{ value: string; label: string; description: string }>>([]);
const [fetchingModels, setFetchingModels] = useState(false);
const [modelsFetched, setModelsFetched] = useState(false);
useEffect(() => {
loadSettings();
}, []);
const loadSettings = async () => {
setInitialLoading(true);
try {
const settings = await settingsApi.getSettings();
form.setFieldsValue(settings);
// 判断是否为默认设置(id='0'表示来自.env的默认配置)
if (settings.id === '0' || !settings.id) {
setIsDefaultSettings(true);
setHasSettings(false);
} else {
setIsDefaultSettings(false);
setHasSettings(true);
}
} catch (error: any) {
// 如果404表示还没有设置,使用默认值
if (error?.response?.status === 404) {
setHasSettings(false);
setIsDefaultSettings(true);
form.setFieldsValue({
api_provider: 'openai',
api_base_url: 'https://api.openai.com/v1',
model_name: 'gpt-4',
temperature: 0.7,
max_tokens: 2000,
});
} else {
message.error('加载设置失败');
}
} finally {
setInitialLoading(false);
}
};
const handleSave = async (values: SettingsUpdate) => {
setLoading(true);
try {
await settingsApi.saveSettings(values);
message.success('设置已保存');
setHasSettings(true);
setIsDefaultSettings(false);
} catch (error) {
message.error('保存设置失败');
} finally {
setLoading(false);
}
};
const handleReset = () => {
Modal.confirm({
title: '重置设置',
content: '确定要重置为默认值吗?',
okText: '确定',
cancelText: '取消',
onOk: () => {
form.setFieldsValue({
api_provider: 'openai',
api_key: '',
api_base_url: 'https://api.openai.com/v1',
model_name: 'gpt-4',
temperature: 0.7,
max_tokens: 2000,
});
message.info('已重置为默认值,请点击保存');
},
});
};
const handleDelete = () => {
Modal.confirm({
title: '删除设置',
content: '确定要删除所有设置吗?此操作不可恢复。',
okText: '确定',
cancelText: '取消',
okType: 'danger',
onOk: async () => {
setLoading(true);
try {
await settingsApi.deleteSettings();
message.success('设置已删除');
setHasSettings(false);
form.resetFields();
} catch (error) {
message.error('删除设置失败');
} finally {
setLoading(false);
}
},
});
};
const apiProviders = [
{ value: 'openai', label: 'OpenAI', defaultUrl: 'https://api.openai.com/v1' },
{ value: 'azure', label: 'Azure OpenAI', defaultUrl: 'https://YOUR-RESOURCE.openai.azure.com' },
{ value: 'anthropic', label: 'Anthropic', defaultUrl: 'https://api.anthropic.com' },
{ value: 'custom', label: '自定义', defaultUrl: '' },
];
const handleProviderChange = (value: string) => {
const provider = apiProviders.find(p => p.value === value);
if (provider && provider.defaultUrl) {
form.setFieldValue('api_base_url', provider.defaultUrl);
}
// 清空模型列表,需要重新获取
setModelOptions([]);
setModelsFetched(false);
};
const handleFetchModels = async (silent: boolean = false) => {
const apiKey = form.getFieldValue('api_key');
const apiBaseUrl = form.getFieldValue('api_base_url');
const provider = form.getFieldValue('api_provider');
if (!apiKey || !apiBaseUrl) {
if (!silent) {
message.warning('请先填写 API 密钥和 API 地址');
}
return;
}
setFetchingModels(true);
try {
const response = await settingsApi.getAvailableModels({
api_key: apiKey,
api_base_url: apiBaseUrl,
provider: provider || 'openai'
});
setModelOptions(response.models);
setModelsFetched(true);
if (!silent) {
message.success(`成功获取 ${response.count || response.models.length} 个可用模型`);
}
} catch (error: any) {
const errorMsg = error?.response?.data?.detail || '获取模型列表失败';
if (!silent) {
message.error(errorMsg);
}
setModelOptions([]);
setModelsFetched(true); // 即使失败也标记为已尝试,避免重复请求
} finally {
setFetchingModels(false);
}
};
const handleModelSelectFocus = () => {
// 如果还没有获取过模型列表,自动获取
if (!modelsFetched && !fetchingModels) {
handleFetchModels(true); // silent模式,不显示成功消息
}
};
return (
<div style={{
minHeight: '100vh',
background: 'linear-gradient(135deg, #667eea 0%, #764ba2 100%)',
padding: window.innerWidth <= 768 ? '20px 16px' : '40px 24px'
}}>
<div style={{ maxWidth: 800, margin: '0 auto' }}>
<Card
variant="borderless"
style={{
background: 'rgba(255, 255, 255, 0.95)',
borderRadius: window.innerWidth <= 768 ? 12 : 16,
boxShadow: '0 8px 32px rgba(0, 0, 0, 0.1)',
}}
>
<Space direction="vertical" size="large" style={{ width: '100%' }}>
{/* 标题栏 */}
<div style={{ display: 'flex', alignItems: 'center', justifyContent: 'space-between' }}>
<Space>
<Button
icon={<ArrowLeftOutlined />}
onClick={() => navigate('/')}
type="text"
/>
<Title level={window.innerWidth <= 768 ? 3 : 2} style={{ margin: 0 }}>
<SettingOutlined style={{ marginRight: 8, color: '#667eea' }} />
AI API
</Title>
</Space>
</div>
<Paragraph type="secondary" style={{ marginBottom: 0 }}>
AI API接口参数AI功能
</Paragraph>
{/* 默认配置提示 */}
{isDefaultSettings && (
<Alert
message="使用 .env 文件中的默认配置"
description={
<div>
<p style={{ margin: '8px 0' }}>
<code>.env</code>
</p>
<p style={{ margin: '8px 0 0 0' }}>
"保存设置" <code>.env</code>
</p>
</div>
}
type="info"
showIcon
style={{ marginBottom: 16 }}
/>
)}
{/* 已保存配置提示 */}
{hasSettings && !isDefaultSettings && (
<Alert
message="使用已保存的个人配置"
type="success"
showIcon
style={{ marginBottom: 16 }}
/>
)}
{/* 表单 */}
<Spin spinning={initialLoading}>
<Form
form={form}
layout="vertical"
onFinish={handleSave}
autoComplete="off"
>
<Form.Item
label={
<Space>
<span>API </span>
<Tooltip title="选择你的AI服务提供商">
<InfoCircleOutlined style={{ color: '#8c8c8c' }} />
</Tooltip>
</Space>
}
name="api_provider"
rules={[{ required: true, message: '请选择API提供商' }]}
>
<Select size="large" onChange={handleProviderChange}>
{apiProviders.map(provider => (
<Option key={provider.value} value={provider.value}>
{provider.label}
</Option>
))}
</Select>
</Form.Item>
<Form.Item
label={
<Space>
<span>API </span>
<Tooltip title="你的API密钥,将加密存储">
<InfoCircleOutlined style={{ color: '#8c8c8c' }} />
</Tooltip>
</Space>
}
name="api_key"
rules={[{ required: true, message: '请输入API密钥' }]}
>
<Input.Password
size="large"
placeholder="sk-..."
autoComplete="new-password"
/>
</Form.Item>
<Form.Item
label={
<Space>
<span>API </span>
<Tooltip title="API的基础URL地址">
<InfoCircleOutlined style={{ color: '#8c8c8c' }} />
</Tooltip>
</Space>
}
name="api_base_url"
rules={[
{ required: true, message: '请输入API地址' },
{ type: 'url', message: '请输入有效的URL' }
]}
>
<Input
size="large"
placeholder="https://api.openai.com/v1"
/>
</Form.Item>
<Form.Item
label={
<Space>
<span></span>
<Tooltip title="AI模型的名称,如 gpt-4, gpt-3.5-turbo">
<InfoCircleOutlined style={{ color: '#8c8c8c' }} />
</Tooltip>
</Space>
}
name="model_name"
rules={[{ required: true, message: '请输入或选择模型名称' }]}
>
<Select
size="large"
showSearch
placeholder="输入模型名称或点击获取"
optionFilterProp="label"
loading={fetchingModels}
onFocus={handleModelSelectFocus}
filterOption={(input, option) =>
(option?.label ?? '').toLowerCase().includes(input.toLowerCase()) ||
(option?.description ?? '').toLowerCase().includes(input.toLowerCase())
}
dropdownRender={(menu) => (
<>
{menu}
{fetchingModels && (
<div style={{ padding: '8px 12px', color: '#8c8c8c', textAlign: 'center' }}>
<Spin size="small" /> ...
</div>
)}
{!fetchingModels && modelOptions.length === 0 && modelsFetched && (
<div style={{ padding: '8px 12px', color: '#ff4d4f', textAlign: 'center' }}>
API
</div>
)}
{!fetchingModels && modelOptions.length === 0 && !modelsFetched && (
<div style={{ padding: '8px 12px', color: '#8c8c8c', textAlign: 'center' }}>
</div>
)}
</>
)}
notFoundContent={
fetchingModels ? (
<div style={{ padding: '8px 12px', textAlign: 'center' }}>
<Spin size="small" /> ...
</div>
) : (
<div style={{ padding: '8px 12px', color: '#8c8c8c', textAlign: 'center' }}>
</div>
)
}
suffixIcon={
<div
onClick={(e) => {
e.stopPropagation();
if (!fetchingModels) {
setModelsFetched(false);
handleFetchModels(false);
}
}}
style={{
cursor: fetchingModels ? 'not-allowed' : 'pointer',
display: 'flex',
alignItems: 'center',
padding: '0 4px',
height: '100%',
marginRight: -8
}}
title="重新获取模型列表"
>
<Button
type="text"
size="small"
icon={<ReloadOutlined />}
loading={fetchingModels}
style={{ pointerEvents: 'none' }}
>
</Button>
</div>
}
options={modelOptions.map(model => ({
value: model.value,
label: model.label,
description: model.description
}))}
optionRender={(option) => (
<div>
<div style={{ fontWeight: 500 }}>{option.data.label}</div>
{option.data.description && (
<div style={{ fontSize: '12px', color: '#8c8c8c' }}>
{option.data.description}
</div>
)}
</div>
)}
/>
</Form.Item>
<Form.Item
label={
<Space>
<span></span>
<Tooltip title="控制输出的随机性,值越高越随机(0.0-2.0)">
<InfoCircleOutlined style={{ color: '#8c8c8c' }} />
</Tooltip>
</Space>
}
name="temperature"
>
<Slider
min={0}
max={2}
step={0.1}
marks={{
0: '0.0',
0.7: '0.7',
1: '1.0',
2: '2.0'
}}
/>
</Form.Item>
<Form.Item
label={
<Space>
<span> Token </span>
<Tooltip title="单次请求的最大token数量">
<InfoCircleOutlined style={{ color: '#8c8c8c' }} />
</Tooltip>
</Space>
}
name="max_tokens"
rules={[
{ required: true, message: '请输入最大token数' },
{ type: 'number', min: 1, max: 32000, message: '请输入1-32000之间的数字' }
]}
>
<InputNumber
size="large"
style={{ width: '100%' }}
min={1}
max={32000}
placeholder="2000"
/>
</Form.Item>
{/* 操作按钮 */}
<Form.Item style={{ marginBottom: 0, marginTop: 32 }}>
<Space size="middle" style={{ width: '100%', justifyContent: 'space-between' }}>
<Space>
<Button
type="primary"
size="large"
icon={<SaveOutlined />}
htmlType="submit"
loading={loading}
style={{
background: 'linear-gradient(135deg, #667eea 0%, #764ba2 100%)',
border: 'none'
}}
>
</Button>
<Button
size="large"
icon={<ReloadOutlined />}
onClick={handleReset}
>
</Button>
</Space>
{hasSettings && (
<Button
danger
size="large"
icon={<DeleteOutlined />}
onClick={handleDelete}
loading={loading}
>
</Button>
)}
</Space>
</Form.Item>
</Form>
</Spin>
</Space>
</Card>
</div>
</div>
);
}
+17
View File
@@ -23,6 +23,8 @@ import type {
PolishTextRequest,
GenerateCharactersResponse,
GenerateOutlineResponse,
Settings,
SettingsUpdate,
} from '../types';
const api = axios.create({
@@ -124,6 +126,21 @@ export const userApi = {
getUser: (userId: string) => api.get<unknown, User>(`/users/${userId}`),
};
export const settingsApi = {
getSettings: () => api.get<unknown, Settings>('/settings'),
saveSettings: (data: SettingsUpdate) =>
api.post<unknown, Settings>('/settings', data),
updateSettings: (data: SettingsUpdate) =>
api.put<unknown, Settings>('/settings', data),
deleteSettings: () => api.delete<unknown, { message: string; user_id: string }>('/settings'),
getAvailableModels: (params: { api_key: string; api_base_url: string; provider: string }) =>
api.get<unknown, { provider: string; models: Array<{ value: string; label: string; description: string }>; count?: number }>('/settings/models', { params }),
};
export const projectApi = {
getProjects: () => api.get<unknown, Project[]>('/projects'),
+25
View File
@@ -11,6 +11,31 @@ export interface User {
last_login: string;
}
// 设置类型定义
export interface Settings {
id: string;
user_id: string;
api_provider: string;
api_key: string;
api_base_url: string;
model_name: string;
temperature: number;
max_tokens: number;
preferences?: string;
created_at: string;
updated_at: string;
}
export interface SettingsUpdate {
api_provider?: string;
api_key?: string;
api_base_url?: string;
model_name?: string;
temperature?: number;
max_tokens?: number;
preferences?: string;
}
// LinuxDO 授权 URL 响应
export interface AuthUrlResponse {
auth_url: string;