支持自定义API接口
This commit is contained in:
@@ -18,9 +18,10 @@ from app.schemas.chapter import (
|
||||
ChapterResponse,
|
||||
ChapterListResponse
|
||||
)
|
||||
from app.services.ai_service import ai_service
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.prompt_service import prompt_service
|
||||
from app.logger import get_logger
|
||||
from app.api.settings import get_user_ai_service
|
||||
|
||||
router = APIRouter(prefix="/chapters", tags=["章节管理"])
|
||||
logger = get_logger(__name__)
|
||||
@@ -247,7 +248,8 @@ async def check_can_generate(
|
||||
@router.post("/{chapter_id}/generate", summary="AI创作章节内容")
|
||||
async def generate_chapter_content(
|
||||
chapter_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
"""
|
||||
根据大纲、前置章节内容和项目信息AI创作章节完整内容
|
||||
@@ -372,7 +374,7 @@ async def generate_chapter_content(
|
||||
logger.info(f"开始AI创作章节 {chapter_id}")
|
||||
|
||||
# 调用AI生成
|
||||
ai_content = await ai_service.generate_text(
|
||||
ai_content = await user_ai_service.generate_text(
|
||||
prompt=prompt
|
||||
)
|
||||
|
||||
@@ -410,7 +412,8 @@ async def generate_chapter_content(
|
||||
@router.post("/{chapter_id}/generate-stream", summary="AI创作章节内容(流式)")
|
||||
async def generate_chapter_content_stream(
|
||||
chapter_id: str,
|
||||
request: Request
|
||||
request: Request,
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
"""
|
||||
根据大纲、前置章节内容和项目信息AI创作章节完整内容(流式返回)
|
||||
@@ -569,7 +572,7 @@ async def generate_chapter_content_stream(
|
||||
|
||||
# 流式生成内容
|
||||
full_content = ""
|
||||
async for chunk in ai_service.generate_text_stream(prompt=prompt):
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
full_content += chunk
|
||||
yield f"data: {json.dumps({'type': 'content', 'content': chunk}, ensure_ascii=False)}\n\n"
|
||||
await asyncio.sleep(0) # 让出控制权
|
||||
|
||||
@@ -15,9 +15,10 @@ from app.schemas.character import (
|
||||
CharacterListResponse,
|
||||
CharacterGenerateRequest
|
||||
)
|
||||
from app.services.ai_service import ai_service
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.prompt_service import prompt_service
|
||||
from app.logger import get_logger
|
||||
from app.api.settings import get_user_ai_service
|
||||
|
||||
router = APIRouter(prefix="/characters", tags=["角色管理"])
|
||||
logger = get_logger(__name__)
|
||||
@@ -134,7 +135,8 @@ async def delete_character(
|
||||
@router.post("/generate", response_model=CharacterResponse, summary="AI生成角色")
|
||||
async def generate_character(
|
||||
request: CharacterGenerateRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
"""
|
||||
使用AI生成角色卡
|
||||
@@ -216,7 +218,7 @@ async def generate_character(
|
||||
logger.info(f" - Prompt长度:{len(prompt)} 字符")
|
||||
|
||||
try:
|
||||
ai_response = await ai_service.generate_text(
|
||||
ai_response = await user_ai_service.generate_text(
|
||||
prompt=prompt,
|
||||
provider=request.provider,
|
||||
model=request.model
|
||||
|
||||
@@ -19,9 +19,10 @@ from app.schemas.outline import (
|
||||
OutlineGenerateRequest,
|
||||
OutlineReorderRequest
|
||||
)
|
||||
from app.services.ai_service import ai_service
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.prompt_service import prompt_service
|
||||
from app.logger import get_logger
|
||||
from app.api.settings import get_user_ai_service
|
||||
|
||||
router = APIRouter(prefix="/outlines", tags=["大纲管理"])
|
||||
logger = get_logger(__name__)
|
||||
@@ -326,7 +327,8 @@ async def reorder_outlines(
|
||||
@router.post("/generate", response_model=OutlineListResponse, summary="AI生成/续写大纲")
|
||||
async def generate_outline(
|
||||
request: OutlineGenerateRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
"""
|
||||
使用AI生成或续写小说大纲 - 智能模式
|
||||
@@ -363,7 +365,7 @@ async def generate_outline(
|
||||
# 模式:全新生成
|
||||
if actual_mode == "new":
|
||||
return await _generate_new_outline(
|
||||
request, project, db
|
||||
request, project, db, user_ai_service
|
||||
)
|
||||
|
||||
# 模式:续写
|
||||
@@ -375,7 +377,7 @@ async def generate_outline(
|
||||
)
|
||||
|
||||
return await _continue_outline(
|
||||
request, project, existing_outlines, db
|
||||
request, project, existing_outlines, db, user_ai_service
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -394,7 +396,8 @@ async def generate_outline(
|
||||
async def _generate_new_outline(
|
||||
request: OutlineGenerateRequest,
|
||||
project: Project,
|
||||
db: AsyncSession
|
||||
db: AsyncSession,
|
||||
user_ai_service: AIService
|
||||
) -> OutlineListResponse:
|
||||
"""全新生成大纲"""
|
||||
logger.info(f"全新生成大纲 - 项目: {project.id}, keep_existing: {request.keep_existing}")
|
||||
@@ -427,7 +430,7 @@ async def _generate_new_outline(
|
||||
)
|
||||
|
||||
# 调用AI
|
||||
ai_response = await ai_service.generate_text(
|
||||
ai_response = await user_ai_service.generate_text(
|
||||
prompt=prompt,
|
||||
provider=request.provider,
|
||||
model=request.model
|
||||
@@ -473,7 +476,8 @@ async def _continue_outline(
|
||||
request: OutlineGenerateRequest,
|
||||
project: Project,
|
||||
existing_outlines: List[Outline],
|
||||
db: AsyncSession
|
||||
db: AsyncSession,
|
||||
user_ai_service: AIService
|
||||
) -> OutlineListResponse:
|
||||
"""续写大纲"""
|
||||
logger.info(f"续写大纲 - 项目: {project.id}, 已有: {len(existing_outlines)} 章")
|
||||
@@ -536,7 +540,7 @@ async def _continue_outline(
|
||||
)
|
||||
|
||||
# 调用AI
|
||||
ai_response = await ai_service.generate_text(
|
||||
ai_response = await user_ai_service.generate_text(
|
||||
prompt=prompt,
|
||||
provider=request.provider,
|
||||
model=request.model
|
||||
|
||||
@@ -5,9 +5,10 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.database import get_db
|
||||
from app.models.generation_history import GenerationHistory
|
||||
from app.schemas.polish import PolishRequest, PolishResponse
|
||||
from app.services.ai_service import ai_service
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.prompt_service import prompt_service
|
||||
from app.logger import get_logger
|
||||
from app.api.settings import get_user_ai_service
|
||||
|
||||
router = APIRouter(prefix="/polish", tags=["AI去味"])
|
||||
logger = get_logger(__name__)
|
||||
@@ -16,7 +17,8 @@ logger = get_logger(__name__)
|
||||
@router.post("", response_model=PolishResponse, summary="AI去味")
|
||||
async def polish_text(
|
||||
request: PolishRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
"""
|
||||
AI去味 - 将AI生成的文本改写得更像人类作家的手笔
|
||||
@@ -83,7 +85,8 @@ async def polish_batch(
|
||||
project_id: int = None,
|
||||
provider: str = None,
|
||||
model: str = None,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
"""
|
||||
批量处理多个文本的AI去味
|
||||
@@ -98,7 +101,7 @@ async def polish_batch(
|
||||
|
||||
prompt = prompt_service.get_denoising_prompt(original_text=text)
|
||||
|
||||
polished_text = await ai_service.generate_text(
|
||||
polished_text = await user_ai_service.generate_text(
|
||||
prompt=prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
|
||||
@@ -0,0 +1,299 @@
|
||||
"""
|
||||
设置管理 API
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, Request, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from typing import Dict, Any, List
|
||||
from pathlib import Path
|
||||
import httpx
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.settings import Settings
|
||||
from app.schemas.settings import SettingsCreate, SettingsUpdate, SettingsResponse
|
||||
from app.user_manager import User
|
||||
from app.logger import get_logger
|
||||
from app.config import settings as app_settings, PROJECT_ROOT
|
||||
from app.services.ai_service import AIService, create_user_ai_service
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/settings", tags=["设置管理"])
|
||||
|
||||
|
||||
def read_env_defaults() -> Dict[str, Any]:
|
||||
"""从.env文件读取默认配置(仅读取,不修改)"""
|
||||
return {
|
||||
"api_provider": app_settings.default_ai_provider,
|
||||
"api_key": app_settings.openai_api_key or app_settings.anthropic_api_key or "",
|
||||
"api_base_url": app_settings.openai_base_url or app_settings.anthropic_base_url or "",
|
||||
"model_name": app_settings.default_model,
|
||||
"temperature": app_settings.default_temperature,
|
||||
"max_tokens": app_settings.default_max_tokens,
|
||||
}
|
||||
|
||||
|
||||
def require_login(request: Request):
|
||||
"""依赖:要求用户已登录"""
|
||||
if not hasattr(request.state, "user") or not request.state.user:
|
||||
raise HTTPException(status_code=401, detail="需要登录")
|
||||
return request.state.user
|
||||
|
||||
|
||||
async def get_user_ai_service(
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> AIService:
|
||||
"""
|
||||
依赖:获取当前用户的AI服务实例
|
||||
从数据库读取用户设置并创建对应的AI服务
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(Settings).where(Settings.user_id == user.user_id)
|
||||
)
|
||||
settings = result.scalar_one_or_none()
|
||||
|
||||
if not settings:
|
||||
# 如果用户没有设置,从.env读取并保存
|
||||
env_defaults = read_env_defaults()
|
||||
settings = Settings(
|
||||
user_id=user.user_id,
|
||||
**env_defaults
|
||||
)
|
||||
db.add(settings)
|
||||
await db.commit()
|
||||
await db.refresh(settings)
|
||||
logger.info(f"用户 {user.user_id} 首次使用AI服务,已从.env同步设置到数据库")
|
||||
|
||||
# 使用用户设置创建AI服务实例
|
||||
return create_user_ai_service(
|
||||
api_provider=settings.api_provider,
|
||||
api_key=settings.api_key,
|
||||
api_base_url=settings.api_base_url or "",
|
||||
model_name=settings.model_name,
|
||||
temperature=settings.temperature,
|
||||
max_tokens=settings.max_tokens
|
||||
)
|
||||
|
||||
|
||||
@router.get("", response_model=SettingsResponse)
|
||||
async def get_settings(
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取当前用户的设置
|
||||
如果用户没有保存过设置,自动从.env创建并保存到数据库
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(Settings).where(Settings.user_id == user.user_id)
|
||||
)
|
||||
settings = result.scalar_one_or_none()
|
||||
|
||||
if not settings:
|
||||
# 如果用户没有保存过设置,从.env读取默认配置并保存到数据库
|
||||
env_defaults = read_env_defaults()
|
||||
logger.info(f"用户 {user.user_id} 首次获取设置,自动从.env同步到数据库")
|
||||
|
||||
# 创建新设置并保存到数据库
|
||||
settings = Settings(
|
||||
user_id=user.user_id,
|
||||
**env_defaults
|
||||
)
|
||||
db.add(settings)
|
||||
await db.commit()
|
||||
await db.refresh(settings)
|
||||
logger.info(f"用户 {user.user_id} 的设置已从.env同步到数据库")
|
||||
|
||||
logger.info(f"用户 {user.user_id} 获取已保存的设置")
|
||||
return settings
|
||||
|
||||
|
||||
@router.post("", response_model=SettingsResponse)
|
||||
async def save_settings(
|
||||
data: SettingsCreate,
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
创建或更新当前用户的设置(Upsert)
|
||||
如果设置已存在则更新,否则创建新设置
|
||||
仅保存到数据库
|
||||
"""
|
||||
# 查找现有设置
|
||||
result = await db.execute(
|
||||
select(Settings).where(Settings.user_id == user.user_id)
|
||||
)
|
||||
settings = result.scalar_one_or_none()
|
||||
|
||||
# 准备数据
|
||||
settings_dict = data.model_dump(exclude_unset=True)
|
||||
|
||||
if settings:
|
||||
# 更新现有设置
|
||||
for key, value in settings_dict.items():
|
||||
setattr(settings, key, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(settings)
|
||||
logger.info(f"用户 {user.user_id} 更新设置")
|
||||
else:
|
||||
# 创建新设置
|
||||
settings = Settings(
|
||||
user_id=user.user_id,
|
||||
**settings_dict
|
||||
)
|
||||
db.add(settings)
|
||||
await db.commit()
|
||||
await db.refresh(settings)
|
||||
logger.info(f"用户 {user.user_id} 创建设置")
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
@router.put("", response_model=SettingsResponse)
|
||||
async def update_settings(
|
||||
data: SettingsUpdate,
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
更新当前用户的设置
|
||||
仅保存到数据库
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(Settings).where(Settings.user_id == user.user_id)
|
||||
)
|
||||
settings = result.scalar_one_or_none()
|
||||
|
||||
if not settings:
|
||||
raise HTTPException(status_code=404, detail="设置不存在,请先创建设置")
|
||||
|
||||
# 更新设置
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(settings, key, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(settings)
|
||||
logger.info(f"用户 {user.user_id} 更新设置")
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
@router.delete("")
|
||||
async def delete_settings(
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
删除当前用户的设置
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(Settings).where(Settings.user_id == user.user_id)
|
||||
)
|
||||
settings = result.scalar_one_or_none()
|
||||
|
||||
if not settings:
|
||||
raise HTTPException(status_code=404, detail="设置不存在")
|
||||
|
||||
await db.delete(settings)
|
||||
await db.commit()
|
||||
logger.info(f"用户 {user.user_id} 删除设置")
|
||||
|
||||
return {"message": "设置已删除", "user_id": user.user_id}
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
async def get_available_models(
|
||||
api_key: str,
|
||||
api_base_url: str,
|
||||
provider: str = "openai"
|
||||
):
|
||||
"""
|
||||
从配置的 API 获取可用的模型列表
|
||||
|
||||
Args:
|
||||
api_key: API 密钥
|
||||
api_base_url: API 基础 URL
|
||||
provider: API 提供商 (openai, anthropic, azure, custom)
|
||||
|
||||
Returns:
|
||||
模型列表
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
if provider == "openai" or provider == "azure" or provider == "custom":
|
||||
# OpenAI 兼容接口获取模型列表
|
||||
url = f"{api_base_url.rstrip('/')}/models"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
logger.info(f"正在从 {url} 获取模型列表")
|
||||
response = await client.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
models = []
|
||||
|
||||
if "data" in data and isinstance(data["data"], list):
|
||||
for model in data["data"]:
|
||||
model_id = model.get("id", "")
|
||||
# 过滤出常用的文本生成模型
|
||||
if any(keyword in model_id.lower() for keyword in [
|
||||
"gpt", "gemini", "claude", "llama", "mistral", "qwen", "deepseek"
|
||||
]):
|
||||
models.append({
|
||||
"value": model_id,
|
||||
"label": model_id,
|
||||
"description": model.get("description", "") or f"Created: {model.get('created', 'N/A')}"
|
||||
})
|
||||
|
||||
if not models:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="未能从 API 获取到可用的模型列表"
|
||||
)
|
||||
|
||||
logger.info(f"成功获取 {len(models)} 个模型")
|
||||
return {
|
||||
"provider": provider,
|
||||
"models": models,
|
||||
"count": len(models)
|
||||
}
|
||||
|
||||
elif provider == "anthropic":
|
||||
# Anthropic 没有公开的模型列表API
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Anthropic 不支持自动获取模型列表,请手动输入模型名称"
|
||||
)
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"不支持的提供商: {provider}"
|
||||
)
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"获取模型列表失败 (HTTP {e.response.status_code}): {e.response.text}")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"无法从 API 获取模型列表 (HTTP {e.response.status_code})"
|
||||
)
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"请求模型列表失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"无法连接到 API: {str(e)}"
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取模型列表时发生错误: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"获取模型列表失败: {str(e)}"
|
||||
)
|
||||
@@ -4,6 +4,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from typing import Dict, Any, AsyncGenerator
|
||||
import json
|
||||
import re
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.project import Project
|
||||
@@ -11,10 +12,11 @@ from app.models.character import Character
|
||||
from app.models.outline import Outline
|
||||
from app.models.chapter import Chapter
|
||||
from app.models.relationship import CharacterRelationship, Organization, OrganizationMember, RelationshipType
|
||||
from app.services.ai_service import ai_service
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.prompt_service import prompt_service
|
||||
from app.logger import get_logger
|
||||
from app.utils.sse_response import SSEResponse, create_sse_response
|
||||
from app.api.settings import get_user_ai_service
|
||||
|
||||
router = APIRouter(prefix="/wizard-stream", tags=["项目创建向导(流式)"])
|
||||
logger = get_logger(__name__)
|
||||
@@ -22,7 +24,8 @@ logger = get_logger(__name__)
|
||||
|
||||
async def world_building_generator(
|
||||
data: Dict[str, Any],
|
||||
db: AsyncSession
|
||||
db: AsyncSession,
|
||||
user_ai_service: AIService
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""世界构建流式生成器"""
|
||||
# 标记数据库会话是否已提交
|
||||
@@ -61,7 +64,7 @@ async def world_building_generator(
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in ai_service.generate_text_stream(
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
@@ -87,24 +90,26 @@ async def world_building_generator(
|
||||
world_data = {}
|
||||
try:
|
||||
cleaned_text = accumulated_text.strip()
|
||||
|
||||
# 移除markdown代码块标记
|
||||
if cleaned_text.startswith('```json'):
|
||||
cleaned_text = cleaned_text[7:]
|
||||
if cleaned_text.startswith('```'):
|
||||
cleaned_text = cleaned_text[3:]
|
||||
cleaned_text = cleaned_text[7:].lstrip('\n\r')
|
||||
elif cleaned_text.startswith('```'):
|
||||
cleaned_text = cleaned_text[3:].lstrip('\n\r')
|
||||
if cleaned_text.endswith('```'):
|
||||
cleaned_text = cleaned_text[:-3]
|
||||
cleaned_text = cleaned_text[:-3].rstrip('\n\r')
|
||||
cleaned_text = cleaned_text.strip()
|
||||
|
||||
world_data = json.loads(cleaned_text)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"AI返回非JSON格式: {e}")
|
||||
logger.error(f"世界构建JSON解析失败: {e}")
|
||||
world_data = {
|
||||
"time_period": accumulated_text[:300] if len(accumulated_text) > 300 else accumulated_text,
|
||||
"time_period": "AI返回格式错误,请重试",
|
||||
"location": "AI返回格式错误,请重试",
|
||||
"atmosphere": "AI返回格式错误,请重试",
|
||||
"rules": "AI返回格式错误,请重试"
|
||||
}
|
||||
|
||||
# 保存到数据库
|
||||
yield await SSEResponse.send_progress("保存到数据库...", 90)
|
||||
|
||||
@@ -160,18 +165,20 @@ async def world_building_generator(
|
||||
@router.post("/world-building", summary="流式生成世界构建")
|
||||
async def generate_world_building_stream(
|
||||
data: Dict[str, Any],
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
"""
|
||||
使用SSE流式生成世界构建,避免超时
|
||||
前端使用EventSource接收实时进度和结果
|
||||
"""
|
||||
return create_sse_response(world_building_generator(data, db))
|
||||
return create_sse_response(world_building_generator(data, db, user_ai_service))
|
||||
|
||||
|
||||
async def characters_generator(
|
||||
data: Dict[str, Any],
|
||||
db: AsyncSession
|
||||
db: AsyncSession,
|
||||
user_ai_service: AIService
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""角色批量生成流式生成器 - 优化版:分批+重试"""
|
||||
db_committed = False
|
||||
@@ -270,7 +277,7 @@ async def characters_generator(
|
||||
|
||||
# 流式生成
|
||||
accumulated_text = ""
|
||||
async for chunk in ai_service.generate_text_stream(
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
@@ -280,12 +287,13 @@ async def characters_generator(
|
||||
|
||||
# 解析批次结果
|
||||
cleaned_text = accumulated_text.strip()
|
||||
# 移除markdown代码块标记
|
||||
if cleaned_text.startswith('```json'):
|
||||
cleaned_text = cleaned_text[7:]
|
||||
if cleaned_text.startswith('```'):
|
||||
cleaned_text = cleaned_text[3:]
|
||||
cleaned_text = cleaned_text[7:].lstrip('\n\r')
|
||||
elif cleaned_text.startswith('```'):
|
||||
cleaned_text = cleaned_text[3:].lstrip('\n\r')
|
||||
if cleaned_text.endswith('```'):
|
||||
cleaned_text = cleaned_text[:-3]
|
||||
cleaned_text = cleaned_text[:-3].rstrip('\n\r')
|
||||
cleaned_text = cleaned_text.strip()
|
||||
|
||||
characters_data = json.loads(cleaned_text)
|
||||
@@ -684,17 +692,19 @@ async def characters_generator(
|
||||
@router.post("/characters", summary="流式批量生成角色")
|
||||
async def generate_characters_stream(
|
||||
data: Dict[str, Any],
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
"""
|
||||
使用SSE流式批量生成角色,避免超时
|
||||
"""
|
||||
return create_sse_response(characters_generator(data, db))
|
||||
return create_sse_response(characters_generator(data, db, user_ai_service))
|
||||
|
||||
|
||||
async def outline_generator(
|
||||
data: Dict[str, Any],
|
||||
db: AsyncSession
|
||||
db: AsyncSession,
|
||||
user_ai_service: AIService
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""大纲生成流式生成器 - 向导固定生成前8章作为开局"""
|
||||
db_committed = False
|
||||
@@ -778,6 +788,7 @@ async def outline_generator(
|
||||
batch_requirements += "2. 建立主线冲突和故事钩子\n"
|
||||
batch_requirements += "3. 展开初期情节,为后续发展埋下伏笔\n"
|
||||
batch_requirements += "4. 不要试图完结故事,这只是开始部分\n"
|
||||
batch_requirements += "5. 不要在JSON字符串值中使用中文引号(""''),请使用【】或《》标记\n"
|
||||
|
||||
batch_prompt = prompt_service.get_complete_outline_prompt(
|
||||
title=project.title,
|
||||
@@ -796,7 +807,7 @@ async def outline_generator(
|
||||
|
||||
# 流式生成
|
||||
accumulated_text = ""
|
||||
async for chunk in ai_service.generate_text_stream(
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=batch_prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
@@ -806,12 +817,14 @@ async def outline_generator(
|
||||
|
||||
# 解析结果
|
||||
cleaned_text = accumulated_text.strip()
|
||||
|
||||
# 移除markdown代码块标记
|
||||
if cleaned_text.startswith('```json'):
|
||||
cleaned_text = cleaned_text[7:]
|
||||
if cleaned_text.startswith('```'):
|
||||
cleaned_text = cleaned_text[3:]
|
||||
cleaned_text = cleaned_text[7:].lstrip('\n\r')
|
||||
elif cleaned_text.startswith('```'):
|
||||
cleaned_text = cleaned_text[3:].lstrip('\n\r')
|
||||
if cleaned_text.endswith('```'):
|
||||
cleaned_text = cleaned_text[:-3]
|
||||
cleaned_text = cleaned_text[:-3].rstrip('\n\r')
|
||||
cleaned_text = cleaned_text.strip()
|
||||
|
||||
batch_outline_data = json.loads(cleaned_text)
|
||||
@@ -839,7 +852,7 @@ async def outline_generator(
|
||||
logger.info(f"批次{batch_idx+1}成功生成{len(batch_outline_data)}章大纲")
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"批次{batch_idx+1}解析失败(尝试{retry_count+1}/{MAX_RETRIES}): {e}")
|
||||
logger.error(f"大纲生成批次{batch_idx+1} JSON解析失败(尝试{retry_count+1}/{MAX_RETRIES}): {e}")
|
||||
retry_count += 1
|
||||
if retry_count < MAX_RETRIES:
|
||||
yield await SSEResponse.send_progress(
|
||||
@@ -945,12 +958,13 @@ async def outline_generator(
|
||||
@router.post("/outline", summary="流式生成完整大纲")
|
||||
async def generate_outline_stream(
|
||||
data: Dict[str, Any],
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
"""
|
||||
使用SSE流式生成完整大纲,避免超时
|
||||
"""
|
||||
return create_sse_response(outline_generator(data, db))
|
||||
return create_sse_response(outline_generator(data, db, user_ai_service))
|
||||
|
||||
|
||||
async def update_world_building_generator(
|
||||
@@ -1037,7 +1051,8 @@ async def update_world_building_stream(
|
||||
async def regenerate_world_building_generator(
|
||||
project_id: str,
|
||||
data: Dict[str, Any],
|
||||
db: AsyncSession
|
||||
db: AsyncSession,
|
||||
user_ai_service: AIService
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""重新生成世界观流式生成器"""
|
||||
db_committed = False
|
||||
@@ -1070,7 +1085,7 @@ async def regenerate_world_building_generator(
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in ai_service.generate_text_stream(
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
@@ -1096,19 +1111,21 @@ async def regenerate_world_building_generator(
|
||||
world_data = {}
|
||||
try:
|
||||
cleaned_text = accumulated_text.strip()
|
||||
# 移除markdown代码块标记
|
||||
if cleaned_text.startswith('```json'):
|
||||
cleaned_text = cleaned_text[7:]
|
||||
if cleaned_text.startswith('```'):
|
||||
cleaned_text = cleaned_text[3:]
|
||||
cleaned_text = cleaned_text[7:].lstrip('\n\r')
|
||||
elif cleaned_text.startswith('```'):
|
||||
cleaned_text = cleaned_text[3:].lstrip('\n\r')
|
||||
if cleaned_text.endswith('```'):
|
||||
cleaned_text = cleaned_text[:-3]
|
||||
cleaned_text = cleaned_text[:-3].rstrip('\n\r')
|
||||
cleaned_text = cleaned_text.strip()
|
||||
|
||||
world_data = json.loads(cleaned_text)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"AI返回非JSON格式: {e}")
|
||||
logger.info(world_data)
|
||||
world_data = {
|
||||
"time_period": accumulated_text[:300] if len(accumulated_text) > 300 else accumulated_text,
|
||||
"time_period": "AI返回格式错误,请重试",
|
||||
"location": "AI返回格式错误,请重试",
|
||||
"atmosphere": "AI返回格式错误,请重试",
|
||||
"rules": "AI返回格式错误,请重试"
|
||||
@@ -1155,7 +1172,8 @@ async def regenerate_world_building_generator(
|
||||
async def regenerate_world_building_stream(
|
||||
project_id: str,
|
||||
data: Dict[str, Any],
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
"""
|
||||
使用SSE流式重新生成项目的世界观
|
||||
@@ -1165,7 +1183,7 @@ async def regenerate_world_building_stream(
|
||||
"model": "模型名称(可选)"
|
||||
}
|
||||
"""
|
||||
return create_sse_response(regenerate_world_building_generator(project_id, data, db))
|
||||
return create_sse_response(regenerate_world_building_generator(project_id, data, db, user_ai_service))
|
||||
|
||||
|
||||
async def cleanup_wizard_data_generator(
|
||||
|
||||
Reference in New Issue
Block a user