支持自定义API接口

This commit is contained in:
xiamuceer
2025-10-30 16:53:50 +08:00
parent fe974d1524
commit 3aefdd433d
16 changed files with 1143 additions and 97 deletions
+56 -38
View File
@@ -4,6 +4,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from typing import Dict, Any, AsyncGenerator
import json
import re
from app.database import get_db
from app.models.project import Project
@@ -11,10 +12,11 @@ from app.models.character import Character
from app.models.outline import Outline
from app.models.chapter import Chapter
from app.models.relationship import CharacterRelationship, Organization, OrganizationMember, RelationshipType
from app.services.ai_service import ai_service
from app.services.ai_service import AIService
from app.services.prompt_service import prompt_service
from app.logger import get_logger
from app.utils.sse_response import SSEResponse, create_sse_response
from app.api.settings import get_user_ai_service
router = APIRouter(prefix="/wizard-stream", tags=["项目创建向导(流式)"])
logger = get_logger(__name__)
@@ -22,7 +24,8 @@ logger = get_logger(__name__)
async def world_building_generator(
data: Dict[str, Any],
db: AsyncSession
db: AsyncSession,
user_ai_service: AIService
) -> AsyncGenerator[str, None]:
"""世界构建流式生成器"""
# 标记数据库会话是否已提交
@@ -61,7 +64,7 @@ async def world_building_generator(
accumulated_text = ""
chunk_count = 0
async for chunk in ai_service.generate_text_stream(
async for chunk in user_ai_service.generate_text_stream(
prompt=prompt,
provider=provider,
model=model
@@ -87,24 +90,26 @@ async def world_building_generator(
world_data = {}
try:
cleaned_text = accumulated_text.strip()
# 移除markdown代码块标记
if cleaned_text.startswith('```json'):
cleaned_text = cleaned_text[7:]
if cleaned_text.startswith('```'):
cleaned_text = cleaned_text[3:]
cleaned_text = cleaned_text[7:].lstrip('\n\r')
elif cleaned_text.startswith('```'):
cleaned_text = cleaned_text[3:].lstrip('\n\r')
if cleaned_text.endswith('```'):
cleaned_text = cleaned_text[:-3]
cleaned_text = cleaned_text[:-3].rstrip('\n\r')
cleaned_text = cleaned_text.strip()
world_data = json.loads(cleaned_text)
except json.JSONDecodeError as e:
logger.error(f"AI返回非JSON格式: {e}")
logger.error(f"世界构建JSON解析失败: {e}")
world_data = {
"time_period": accumulated_text[:300] if len(accumulated_text) > 300 else accumulated_text,
"time_period": "AI返回格式错误,请重试",
"location": "AI返回格式错误,请重试",
"atmosphere": "AI返回格式错误,请重试",
"rules": "AI返回格式错误,请重试"
}
# 保存到数据库
yield await SSEResponse.send_progress("保存到数据库...", 90)
@@ -160,18 +165,20 @@ async def world_building_generator(
@router.post("/world-building", summary="流式生成世界构建")
async def generate_world_building_stream(
data: Dict[str, Any],
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
"""
使用SSE流式生成世界构建,避免超时
前端使用EventSource接收实时进度和结果
"""
return create_sse_response(world_building_generator(data, db))
return create_sse_response(world_building_generator(data, db, user_ai_service))
async def characters_generator(
data: Dict[str, Any],
db: AsyncSession
db: AsyncSession,
user_ai_service: AIService
) -> AsyncGenerator[str, None]:
"""角色批量生成流式生成器 - 优化版:分批+重试"""
db_committed = False
@@ -270,7 +277,7 @@ async def characters_generator(
# 流式生成
accumulated_text = ""
async for chunk in ai_service.generate_text_stream(
async for chunk in user_ai_service.generate_text_stream(
prompt=prompt,
provider=provider,
model=model
@@ -280,12 +287,13 @@ async def characters_generator(
# 解析批次结果
cleaned_text = accumulated_text.strip()
# 移除markdown代码块标记
if cleaned_text.startswith('```json'):
cleaned_text = cleaned_text[7:]
if cleaned_text.startswith('```'):
cleaned_text = cleaned_text[3:]
cleaned_text = cleaned_text[7:].lstrip('\n\r')
elif cleaned_text.startswith('```'):
cleaned_text = cleaned_text[3:].lstrip('\n\r')
if cleaned_text.endswith('```'):
cleaned_text = cleaned_text[:-3]
cleaned_text = cleaned_text[:-3].rstrip('\n\r')
cleaned_text = cleaned_text.strip()
characters_data = json.loads(cleaned_text)
@@ -684,17 +692,19 @@ async def characters_generator(
@router.post("/characters", summary="流式批量生成角色")
async def generate_characters_stream(
data: Dict[str, Any],
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
"""
使用SSE流式批量生成角色,避免超时
"""
return create_sse_response(characters_generator(data, db))
return create_sse_response(characters_generator(data, db, user_ai_service))
async def outline_generator(
data: Dict[str, Any],
db: AsyncSession
db: AsyncSession,
user_ai_service: AIService
) -> AsyncGenerator[str, None]:
"""大纲生成流式生成器 - 向导固定生成前8章作为开局"""
db_committed = False
@@ -778,6 +788,7 @@ async def outline_generator(
batch_requirements += "2. 建立主线冲突和故事钩子\n"
batch_requirements += "3. 展开初期情节,为后续发展埋下伏笔\n"
batch_requirements += "4. 不要试图完结故事,这只是开始部分\n"
batch_requirements += "5. 不要在JSON字符串值中使用中文引号(""''),请使用【】或《》标记\n"
batch_prompt = prompt_service.get_complete_outline_prompt(
title=project.title,
@@ -796,7 +807,7 @@ async def outline_generator(
# 流式生成
accumulated_text = ""
async for chunk in ai_service.generate_text_stream(
async for chunk in user_ai_service.generate_text_stream(
prompt=batch_prompt,
provider=provider,
model=model
@@ -806,12 +817,14 @@ async def outline_generator(
# 解析结果
cleaned_text = accumulated_text.strip()
# 移除markdown代码块标记
if cleaned_text.startswith('```json'):
cleaned_text = cleaned_text[7:]
if cleaned_text.startswith('```'):
cleaned_text = cleaned_text[3:]
cleaned_text = cleaned_text[7:].lstrip('\n\r')
elif cleaned_text.startswith('```'):
cleaned_text = cleaned_text[3:].lstrip('\n\r')
if cleaned_text.endswith('```'):
cleaned_text = cleaned_text[:-3]
cleaned_text = cleaned_text[:-3].rstrip('\n\r')
cleaned_text = cleaned_text.strip()
batch_outline_data = json.loads(cleaned_text)
@@ -839,7 +852,7 @@ async def outline_generator(
logger.info(f"批次{batch_idx+1}成功生成{len(batch_outline_data)}章大纲")
except json.JSONDecodeError as e:
logger.error(f"批次{batch_idx+1}解析失败(尝试{retry_count+1}/{MAX_RETRIES}): {e}")
logger.error(f"大纲生成批次{batch_idx+1} JSON解析失败(尝试{retry_count+1}/{MAX_RETRIES}): {e}")
retry_count += 1
if retry_count < MAX_RETRIES:
yield await SSEResponse.send_progress(
@@ -945,12 +958,13 @@ async def outline_generator(
@router.post("/outline", summary="流式生成完整大纲")
async def generate_outline_stream(
data: Dict[str, Any],
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
"""
使用SSE流式生成完整大纲,避免超时
"""
return create_sse_response(outline_generator(data, db))
return create_sse_response(outline_generator(data, db, user_ai_service))
async def update_world_building_generator(
@@ -1037,7 +1051,8 @@ async def update_world_building_stream(
async def regenerate_world_building_generator(
project_id: str,
data: Dict[str, Any],
db: AsyncSession
db: AsyncSession,
user_ai_service: AIService
) -> AsyncGenerator[str, None]:
"""重新生成世界观流式生成器"""
db_committed = False
@@ -1070,7 +1085,7 @@ async def regenerate_world_building_generator(
accumulated_text = ""
chunk_count = 0
async for chunk in ai_service.generate_text_stream(
async for chunk in user_ai_service.generate_text_stream(
prompt=prompt,
provider=provider,
model=model
@@ -1096,19 +1111,21 @@ async def regenerate_world_building_generator(
world_data = {}
try:
cleaned_text = accumulated_text.strip()
# 移除markdown代码块标记
if cleaned_text.startswith('```json'):
cleaned_text = cleaned_text[7:]
if cleaned_text.startswith('```'):
cleaned_text = cleaned_text[3:]
cleaned_text = cleaned_text[7:].lstrip('\n\r')
elif cleaned_text.startswith('```'):
cleaned_text = cleaned_text[3:].lstrip('\n\r')
if cleaned_text.endswith('```'):
cleaned_text = cleaned_text[:-3]
cleaned_text = cleaned_text[:-3].rstrip('\n\r')
cleaned_text = cleaned_text.strip()
world_data = json.loads(cleaned_text)
except json.JSONDecodeError as e:
logger.error(f"AI返回非JSON格式: {e}")
logger.info(world_data)
world_data = {
"time_period": accumulated_text[:300] if len(accumulated_text) > 300 else accumulated_text,
"time_period": "AI返回格式错误,请重试",
"location": "AI返回格式错误,请重试",
"atmosphere": "AI返回格式错误,请重试",
"rules": "AI返回格式错误,请重试"
@@ -1155,7 +1172,8 @@ async def regenerate_world_building_generator(
async def regenerate_world_building_stream(
project_id: str,
data: Dict[str, Any],
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
"""
使用SSE流式重新生成项目的世界观
@@ -1165,7 +1183,7 @@ async def regenerate_world_building_stream(
"model": "模型名称(可选)"
}
"""
return create_sse_response(regenerate_world_building_generator(project_id, data, db))
return create_sse_response(regenerate_world_building_generator(project_id, data, db, user_ai_service))
async def cleanup_wizard_data_generator(