2025-10-30 11:14:43 +08:00
|
|
|
|
"""项目创建向导流式API - 使用SSE避免超时"""
|
2025-11-07 22:14:20 +08:00
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
2025-10-30 11:14:43 +08:00
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
|
from sqlalchemy import select
|
|
|
|
|
|
from typing import Dict, Any, AsyncGenerator
|
|
|
|
|
|
import json
|
2025-10-30 16:53:50 +08:00
|
|
|
|
import re
|
2025-10-30 11:14:43 +08:00
|
|
|
|
|
|
|
|
|
|
from app.database import get_db
|
|
|
|
|
|
from app.models.project import Project
|
|
|
|
|
|
from app.models.character import Character
|
|
|
|
|
|
from app.models.outline import Outline
|
|
|
|
|
|
from app.models.chapter import Chapter
|
|
|
|
|
|
from app.models.relationship import CharacterRelationship, Organization, OrganizationMember, RelationshipType
|
2025-10-31 17:23:25 +08:00
|
|
|
|
from app.models.writing_style import WritingStyle
|
|
|
|
|
|
from app.models.project_default_style import ProjectDefaultStyle
|
2025-10-30 16:53:50 +08:00
|
|
|
|
from app.services.ai_service import AIService
|
2025-11-07 22:14:20 +08:00
|
|
|
|
from app.services.mcp_tool_service import MCPToolService
|
2025-10-30 11:14:43 +08:00
|
|
|
|
from app.services.prompt_service import prompt_service
|
2025-11-19 13:30:55 +08:00
|
|
|
|
from app.services.plot_expansion_service import PlotExpansionService
|
2025-10-30 11:14:43 +08:00
|
|
|
|
from app.logger import get_logger
|
|
|
|
|
|
from app.utils.sse_response import SSEResponse, create_sse_response
|
2025-10-30 16:53:50 +08:00
|
|
|
|
from app.api.settings import get_user_ai_service
|
2025-10-30 11:14:43 +08:00
|
|
|
|
|
|
|
|
|
|
router = APIRouter(prefix="/wizard-stream", tags=["项目创建向导(流式)"])
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def world_building_generator(
|
|
|
|
|
|
data: Dict[str, Any],
|
2025-10-30 16:53:50 +08:00
|
|
|
|
db: AsyncSession,
|
|
|
|
|
|
user_ai_service: AIService
|
2025-10-30 11:14:43 +08:00
|
|
|
|
) -> AsyncGenerator[str, None]:
|
2025-11-07 22:14:20 +08:00
|
|
|
|
"""世界构建流式生成器 - 支持MCP工具增强"""
|
2025-10-30 11:14:43 +08:00
|
|
|
|
# 标记数据库会话是否已提交
|
|
|
|
|
|
db_committed = False
|
|
|
|
|
|
try:
|
|
|
|
|
|
# 发送开始消息
|
|
|
|
|
|
yield await SSEResponse.send_progress("开始生成世界观...", 10)
|
|
|
|
|
|
|
|
|
|
|
|
# 提取参数
|
|
|
|
|
|
title = data.get("title")
|
|
|
|
|
|
description = data.get("description")
|
|
|
|
|
|
theme = data.get("theme")
|
|
|
|
|
|
genre = data.get("genre")
|
|
|
|
|
|
narrative_perspective = data.get("narrative_perspective")
|
|
|
|
|
|
target_words = data.get("target_words")
|
|
|
|
|
|
chapter_count = data.get("chapter_count")
|
|
|
|
|
|
character_count = data.get("character_count")
|
|
|
|
|
|
provider = data.get("provider")
|
|
|
|
|
|
model = data.get("model")
|
2025-11-07 22:14:20 +08:00
|
|
|
|
enable_mcp = data.get("enable_mcp", True) # 默认启用MCP
|
|
|
|
|
|
user_id = data.get("user_id") # 从中间件注入
|
2025-10-30 11:14:43 +08:00
|
|
|
|
|
|
|
|
|
|
if not title or not description or not theme or not genre:
|
|
|
|
|
|
yield await SSEResponse.send_error("title、description、theme 和 genre 是必需的参数", 400)
|
|
|
|
|
|
return
|
|
|
|
|
|
|
2025-11-07 22:14:20 +08:00
|
|
|
|
# 获取基础提示词
|
|
|
|
|
|
yield await SSEResponse.send_progress("准备AI提示词...", 15)
|
|
|
|
|
|
base_prompt = prompt_service.get_world_building_prompt(
|
2025-10-30 11:14:43 +08:00
|
|
|
|
title=title,
|
|
|
|
|
|
theme=theme,
|
|
|
|
|
|
genre=genre
|
|
|
|
|
|
)
|
|
|
|
|
|
|
2025-11-07 22:14:20 +08:00
|
|
|
|
# MCP工具增强:收集参考资料
|
|
|
|
|
|
reference_materials = ""
|
|
|
|
|
|
if enable_mcp and user_id:
|
|
|
|
|
|
try:
|
|
|
|
|
|
yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集参考资料...", 18)
|
|
|
|
|
|
|
|
|
|
|
|
# 直接调用MCP增强的AI,内部会自动检查和加载工具
|
|
|
|
|
|
# 构建资料收集提示词
|
|
|
|
|
|
planning_prompt = f"""你正在为小说《{title}》设计世界观。
|
|
|
|
|
|
|
|
|
|
|
|
【小说信息】
|
|
|
|
|
|
- 题材:{genre}
|
|
|
|
|
|
- 主题:{theme}
|
|
|
|
|
|
- 简介:{description}
|
|
|
|
|
|
|
|
|
|
|
|
【任务】
|
|
|
|
|
|
请使用可用工具搜索相关背景资料,帮助构建更真实、更有深度的世界观设定。
|
|
|
|
|
|
你可以查询:
|
|
|
|
|
|
1. 历史背景(如果是历史题材)
|
|
|
|
|
|
2. 地理环境和文化特征
|
|
|
|
|
|
3. 相关领域的专业知识
|
|
|
|
|
|
4. 类似作品的设定参考
|
|
|
|
|
|
|
|
|
|
|
|
请根据题材特点,有针对性地查询2-3个关键问题。"""
|
|
|
|
|
|
|
|
|
|
|
|
# 调用MCP增强的AI(非流式,最多2轮工具调用)
|
|
|
|
|
|
planning_result = await user_ai_service.generate_text_with_mcp(
|
|
|
|
|
|
prompt=planning_prompt,
|
|
|
|
|
|
user_id=user_id,
|
|
|
|
|
|
db_session=db,
|
|
|
|
|
|
enable_mcp=True,
|
|
|
|
|
|
max_tool_rounds=2,
|
|
|
|
|
|
tool_choice="auto",
|
|
|
|
|
|
provider=None,
|
|
|
|
|
|
model=None
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
# 提取参考资料
|
|
|
|
|
|
if planning_result.get("tool_calls_made", 0) > 0:
|
|
|
|
|
|
yield await SSEResponse.send_progress(
|
|
|
|
|
|
f"✅ MCP工具调用成功({planning_result['tool_calls_made']}次)",
|
|
|
|
|
|
25
|
|
|
|
|
|
)
|
|
|
|
|
|
reference_materials = planning_result.get("content", "")
|
|
|
|
|
|
else:
|
|
|
|
|
|
yield await SSEResponse.send_progress("ℹ️ 未使用MCP工具(无可用工具或不需要)", 25)
|
|
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
logger.warning(f"MCP工具调用失败(降级处理): {e}")
|
|
|
|
|
|
yield await SSEResponse.send_progress("⚠️ MCP工具暂时不可用,使用基础模式", 25)
|
2025-10-30 11:14:43 +08:00
|
|
|
|
|
2025-11-07 22:14:20 +08:00
|
|
|
|
# 构建增强提示词
|
|
|
|
|
|
if reference_materials:
|
|
|
|
|
|
enhanced_prompt = f"""{base_prompt}
|
|
|
|
|
|
|
|
|
|
|
|
【参考资料】
|
|
|
|
|
|
以下是通过MCP工具收集的真实背景资料,请参考这些信息构建更真实的世界观:
|
|
|
|
|
|
|
|
|
|
|
|
{reference_materials}
|
|
|
|
|
|
|
|
|
|
|
|
请结合上述资料,生成符合历史/现实的世界观设定。"""
|
|
|
|
|
|
final_prompt = enhanced_prompt
|
|
|
|
|
|
yield await SSEResponse.send_progress("💡 已整合参考资料,开始生成世界观...", 30)
|
|
|
|
|
|
else:
|
|
|
|
|
|
final_prompt = base_prompt
|
|
|
|
|
|
yield await SSEResponse.send_progress("正在调用AI生成...", 30)
|
|
|
|
|
|
|
|
|
|
|
|
# 流式生成世界观
|
2025-10-30 11:14:43 +08:00
|
|
|
|
accumulated_text = ""
|
|
|
|
|
|
chunk_count = 0
|
|
|
|
|
|
|
2025-10-30 16:53:50 +08:00
|
|
|
|
async for chunk in user_ai_service.generate_text_stream(
|
2025-11-07 22:14:20 +08:00
|
|
|
|
prompt=final_prompt,
|
2025-10-30 11:14:43 +08:00
|
|
|
|
provider=provider,
|
|
|
|
|
|
model=model
|
|
|
|
|
|
):
|
|
|
|
|
|
chunk_count += 1
|
|
|
|
|
|
accumulated_text += chunk
|
|
|
|
|
|
|
|
|
|
|
|
# 发送内容块
|
|
|
|
|
|
yield await SSEResponse.send_chunk(chunk)
|
|
|
|
|
|
|
|
|
|
|
|
# 定期更新进度
|
|
|
|
|
|
if chunk_count % 5 == 0:
|
|
|
|
|
|
progress = min(30 + (chunk_count // 5), 70)
|
|
|
|
|
|
yield await SSEResponse.send_progress(f"生成中... ({len(accumulated_text)}字符)", progress)
|
|
|
|
|
|
|
|
|
|
|
|
# 每20个块发送心跳
|
|
|
|
|
|
if chunk_count % 20 == 0:
|
|
|
|
|
|
yield await SSEResponse.send_heartbeat()
|
|
|
|
|
|
|
|
|
|
|
|
# 解析结果
|
|
|
|
|
|
yield await SSEResponse.send_progress("解析AI返回结果...", 80)
|
|
|
|
|
|
|
|
|
|
|
|
world_data = {}
|
|
|
|
|
|
try:
|
|
|
|
|
|
cleaned_text = accumulated_text.strip()
|
2025-10-30 16:53:50 +08:00
|
|
|
|
|
|
|
|
|
|
# 移除markdown代码块标记
|
2025-10-30 11:14:43 +08:00
|
|
|
|
if cleaned_text.startswith('```json'):
|
2025-10-30 16:53:50 +08:00
|
|
|
|
cleaned_text = cleaned_text[7:].lstrip('\n\r')
|
|
|
|
|
|
elif cleaned_text.startswith('```'):
|
|
|
|
|
|
cleaned_text = cleaned_text[3:].lstrip('\n\r')
|
2025-10-30 11:14:43 +08:00
|
|
|
|
if cleaned_text.endswith('```'):
|
2025-10-30 16:53:50 +08:00
|
|
|
|
cleaned_text = cleaned_text[:-3].rstrip('\n\r')
|
2025-10-30 11:14:43 +08:00
|
|
|
|
cleaned_text = cleaned_text.strip()
|
|
|
|
|
|
|
|
|
|
|
|
world_data = json.loads(cleaned_text)
|
2025-10-30 16:53:50 +08:00
|
|
|
|
|
2025-10-30 11:14:43 +08:00
|
|
|
|
except json.JSONDecodeError as e:
|
2025-10-30 16:53:50 +08:00
|
|
|
|
logger.error(f"世界构建JSON解析失败: {e}")
|
2025-10-30 11:14:43 +08:00
|
|
|
|
world_data = {
|
2025-10-30 16:53:50 +08:00
|
|
|
|
"time_period": "AI返回格式错误,请重试",
|
2025-10-30 11:14:43 +08:00
|
|
|
|
"location": "AI返回格式错误,请重试",
|
|
|
|
|
|
"atmosphere": "AI返回格式错误,请重试",
|
|
|
|
|
|
"rules": "AI返回格式错误,请重试"
|
|
|
|
|
|
}
|
|
|
|
|
|
# 保存到数据库
|
|
|
|
|
|
yield await SSEResponse.send_progress("保存到数据库...", 90)
|
|
|
|
|
|
|
2025-11-10 21:16:55 +08:00
|
|
|
|
# 确保user_id存在
|
|
|
|
|
|
if not user_id:
|
|
|
|
|
|
yield await SSEResponse.send_error("用户ID缺失,无法创建项目", 401)
|
|
|
|
|
|
return
|
|
|
|
|
|
|
2025-10-30 11:14:43 +08:00
|
|
|
|
project = Project(
|
2025-11-10 21:16:55 +08:00
|
|
|
|
user_id=user_id, # 添加user_id字段
|
2025-10-30 11:14:43 +08:00
|
|
|
|
title=title,
|
|
|
|
|
|
description=description,
|
|
|
|
|
|
theme=theme,
|
|
|
|
|
|
genre=genre,
|
|
|
|
|
|
world_time_period=world_data.get("time_period"),
|
|
|
|
|
|
world_location=world_data.get("location"),
|
|
|
|
|
|
world_atmosphere=world_data.get("atmosphere"),
|
|
|
|
|
|
world_rules=world_data.get("rules"),
|
|
|
|
|
|
narrative_perspective=narrative_perspective,
|
|
|
|
|
|
target_words=target_words,
|
|
|
|
|
|
chapter_count=chapter_count,
|
|
|
|
|
|
character_count=character_count,
|
|
|
|
|
|
wizard_status="incomplete",
|
|
|
|
|
|
wizard_step=1,
|
|
|
|
|
|
status="planning"
|
|
|
|
|
|
)
|
|
|
|
|
|
db.add(project)
|
|
|
|
|
|
await db.commit()
|
|
|
|
|
|
await db.refresh(project)
|
|
|
|
|
|
|
2025-10-31 17:23:25 +08:00
|
|
|
|
# 自动设置默认写作风格为第一个全局预设风格
|
|
|
|
|
|
try:
|
|
|
|
|
|
result = await db.execute(
|
|
|
|
|
|
select(WritingStyle).where(
|
|
|
|
|
|
WritingStyle.project_id.is_(None),
|
|
|
|
|
|
WritingStyle.order_index == 1
|
|
|
|
|
|
).limit(1)
|
|
|
|
|
|
)
|
|
|
|
|
|
first_style = result.scalar_one_or_none()
|
|
|
|
|
|
|
|
|
|
|
|
if first_style:
|
|
|
|
|
|
default_style = ProjectDefaultStyle(
|
|
|
|
|
|
project_id=project.id,
|
|
|
|
|
|
style_id=first_style.id
|
|
|
|
|
|
)
|
|
|
|
|
|
db.add(default_style)
|
|
|
|
|
|
await db.commit()
|
|
|
|
|
|
logger.info(f"为项目 {project.id} 自动设置默认风格: {first_style.name}")
|
|
|
|
|
|
else:
|
|
|
|
|
|
logger.warning(f"未找到order_index=1的全局预设风格,项目 {project.id} 未设置默认风格")
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
logger.warning(f"设置默认写作风格失败: {e},不影响项目创建")
|
|
|
|
|
|
|
|
|
|
|
|
db_committed = True
|
|
|
|
|
|
|
2025-10-30 11:14:43 +08:00
|
|
|
|
# 发送最终结果
|
|
|
|
|
|
yield await SSEResponse.send_result({
|
|
|
|
|
|
"project_id": project.id,
|
|
|
|
|
|
"time_period": world_data.get("time_period"),
|
|
|
|
|
|
"location": world_data.get("location"),
|
|
|
|
|
|
"atmosphere": world_data.get("atmosphere"),
|
|
|
|
|
|
"rules": world_data.get("rules")
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
yield await SSEResponse.send_progress("完成!", 100, "success")
|
|
|
|
|
|
yield await SSEResponse.send_done()
|
|
|
|
|
|
|
|
|
|
|
|
except GeneratorExit:
|
|
|
|
|
|
# SSE连接断开,回滚未提交的事务
|
|
|
|
|
|
logger.warning("世界构建生成器被提前关闭")
|
|
|
|
|
|
if not db_committed and db.in_transaction():
|
|
|
|
|
|
await db.rollback()
|
|
|
|
|
|
logger.info("世界构建事务已回滚(GeneratorExit)")
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
logger.error(f"世界构建流式生成失败: {str(e)}")
|
|
|
|
|
|
# 异常时回滚事务
|
|
|
|
|
|
if not db_committed and db.in_transaction():
|
|
|
|
|
|
await db.rollback()
|
|
|
|
|
|
logger.info("世界构建事务已回滚(异常)")
|
|
|
|
|
|
yield await SSEResponse.send_error(f"生成失败: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/world-building", summary="流式生成世界构建")
|
|
|
|
|
|
async def generate_world_building_stream(
|
2025-11-07 22:14:20 +08:00
|
|
|
|
request: Request,
|
2025-10-30 11:14:43 +08:00
|
|
|
|
data: Dict[str, Any],
|
2025-10-30 16:53:50 +08:00
|
|
|
|
db: AsyncSession = Depends(get_db),
|
|
|
|
|
|
user_ai_service: AIService = Depends(get_user_ai_service)
|
2025-10-30 11:14:43 +08:00
|
|
|
|
):
|
|
|
|
|
|
"""
|
|
|
|
|
|
使用SSE流式生成世界构建,避免超时
|
|
|
|
|
|
前端使用EventSource接收实时进度和结果
|
|
|
|
|
|
"""
|
2025-11-07 22:14:20 +08:00
|
|
|
|
# 从中间件注入user_id到data中
|
|
|
|
|
|
if hasattr(request.state, 'user_id'):
|
|
|
|
|
|
data['user_id'] = request.state.user_id
|
|
|
|
|
|
|
2025-10-30 16:53:50 +08:00
|
|
|
|
return create_sse_response(world_building_generator(data, db, user_ai_service))
|
2025-10-30 11:14:43 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def characters_generator(
|
|
|
|
|
|
data: Dict[str, Any],
|
2025-10-30 16:53:50 +08:00
|
|
|
|
db: AsyncSession,
|
|
|
|
|
|
user_ai_service: AIService
|
2025-10-30 11:14:43 +08:00
|
|
|
|
) -> AsyncGenerator[str, None]:
|
2025-11-07 22:14:20 +08:00
|
|
|
|
"""角色批量生成流式生成器 - 优化版:分批+重试+MCP工具增强"""
|
2025-10-30 11:14:43 +08:00
|
|
|
|
db_committed = False
|
|
|
|
|
|
try:
|
|
|
|
|
|
yield await SSEResponse.send_progress("开始生成角色...", 5)
|
|
|
|
|
|
|
|
|
|
|
|
project_id = data.get("project_id")
|
|
|
|
|
|
count = data.get("count", 5)
|
|
|
|
|
|
world_context = data.get("world_context")
|
|
|
|
|
|
theme = data.get("theme", "")
|
|
|
|
|
|
genre = data.get("genre", "")
|
|
|
|
|
|
requirements = data.get("requirements", "")
|
|
|
|
|
|
provider = data.get("provider")
|
|
|
|
|
|
model = data.get("model")
|
2025-11-07 22:14:20 +08:00
|
|
|
|
enable_mcp = data.get("enable_mcp", True) # 默认启用MCP
|
|
|
|
|
|
user_id = data.get("user_id") # 从中间件注入
|
2025-10-30 11:14:43 +08:00
|
|
|
|
|
|
|
|
|
|
# 验证项目
|
|
|
|
|
|
yield await SSEResponse.send_progress("验证项目...", 10)
|
|
|
|
|
|
result = await db.execute(
|
|
|
|
|
|
select(Project).where(Project.id == project_id)
|
|
|
|
|
|
)
|
|
|
|
|
|
project = result.scalar_one_or_none()
|
|
|
|
|
|
if not project:
|
|
|
|
|
|
yield await SSEResponse.send_error("项目不存在", 404)
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
project.wizard_step = 2
|
|
|
|
|
|
|
|
|
|
|
|
world_context = world_context or {
|
|
|
|
|
|
"time_period": project.world_time_period or "未设定",
|
|
|
|
|
|
"location": project.world_location or "未设定",
|
|
|
|
|
|
"atmosphere": project.world_atmosphere or "未设定",
|
|
|
|
|
|
"rules": project.world_rules or "未设定"
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-11-07 22:14:20 +08:00
|
|
|
|
# MCP工具增强:收集角色参考资料
|
|
|
|
|
|
character_reference_materials = ""
|
|
|
|
|
|
if enable_mcp and user_id:
|
|
|
|
|
|
try:
|
|
|
|
|
|
yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集角色参考资料...", 8)
|
|
|
|
|
|
|
|
|
|
|
|
# 构建角色资料收集提示词
|
|
|
|
|
|
planning_prompt = f"""你正在为小说《{project.title}》设计角色。
|
|
|
|
|
|
|
|
|
|
|
|
【小说信息】
|
|
|
|
|
|
- 题材:{genre or project.genre}
|
|
|
|
|
|
- 主题:{theme or project.theme}
|
|
|
|
|
|
- 时代背景:{world_context.get('time_period', '未设定')}
|
|
|
|
|
|
- 地理位置:{world_context.get('location', '未设定')}
|
|
|
|
|
|
|
|
|
|
|
|
【任务】
|
|
|
|
|
|
请使用可用工具搜索相关参考资料,帮助设计更真实、更有深度的角色。
|
|
|
|
|
|
你可以查询:
|
|
|
|
|
|
1. 该时代/地域的真实历史人物特征
|
|
|
|
|
|
2. 文化背景和社会习俗
|
|
|
|
|
|
3. 职业特点和生活方式
|
|
|
|
|
|
4. 相关领域的人物原型
|
|
|
|
|
|
|
|
|
|
|
|
请根据题材特点,有针对性地查询1-2个关键问题。"""
|
|
|
|
|
|
|
|
|
|
|
|
# 调用MCP增强的AI(非流式,最多2轮工具调用)
|
|
|
|
|
|
planning_result = await user_ai_service.generate_text_with_mcp(
|
|
|
|
|
|
prompt=planning_prompt,
|
|
|
|
|
|
user_id=user_id,
|
|
|
|
|
|
db_session=db,
|
|
|
|
|
|
enable_mcp=True,
|
|
|
|
|
|
max_tool_rounds=2,
|
|
|
|
|
|
tool_choice="auto",
|
|
|
|
|
|
provider=None,
|
|
|
|
|
|
model=None
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
# 提取参考资料
|
|
|
|
|
|
if planning_result.get("tool_calls_made", 0) > 0:
|
|
|
|
|
|
yield await SSEResponse.send_progress(
|
|
|
|
|
|
f"✅ MCP工具调用成功({planning_result['tool_calls_made']}次)",
|
|
|
|
|
|
12
|
|
|
|
|
|
)
|
|
|
|
|
|
character_reference_materials = planning_result.get("content", "")
|
|
|
|
|
|
else:
|
|
|
|
|
|
yield await SSEResponse.send_progress("ℹ️ 未使用MCP工具(无可用工具或不需要)", 12)
|
|
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
logger.warning(f"MCP工具调用失败(降级处理): {e}")
|
|
|
|
|
|
yield await SSEResponse.send_progress("⚠️ MCP工具暂时不可用,使用基础模式", 12)
|
|
|
|
|
|
|
2025-10-30 11:14:43 +08:00
|
|
|
|
# 优化的分批策略:每批生成3个,平衡效率和成功率
|
|
|
|
|
|
BATCH_SIZE = 3 # 每批生成3个角色
|
|
|
|
|
|
MAX_RETRIES = 3 # 每批最多重试3次
|
|
|
|
|
|
all_characters = []
|
|
|
|
|
|
total_batches = (count + BATCH_SIZE - 1) // BATCH_SIZE
|
|
|
|
|
|
|
|
|
|
|
|
for batch_idx in range(total_batches):
|
|
|
|
|
|
# 精确计算当前批次应该生成的数量
|
|
|
|
|
|
remaining = count - len(all_characters)
|
|
|
|
|
|
current_batch_size = min(BATCH_SIZE, remaining)
|
|
|
|
|
|
|
|
|
|
|
|
# 如果已经达到目标数量,直接退出
|
|
|
|
|
|
if current_batch_size <= 0:
|
|
|
|
|
|
logger.info(f"已生成{len(all_characters)}个角色,达到目标数量{count}")
|
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
|
|
batch_progress = 15 + (batch_idx * 60 // total_batches)
|
|
|
|
|
|
|
|
|
|
|
|
# 重试逻辑
|
|
|
|
|
|
retry_count = 0
|
|
|
|
|
|
batch_success = False
|
2025-11-03 15:28:51 +08:00
|
|
|
|
batch_error_message = ""
|
2025-10-30 11:14:43 +08:00
|
|
|
|
|
|
|
|
|
|
while retry_count < MAX_RETRIES and not batch_success:
|
|
|
|
|
|
try:
|
|
|
|
|
|
retry_suffix = f" (重试{retry_count}/{MAX_RETRIES})" if retry_count > 0 else ""
|
|
|
|
|
|
yield await SSEResponse.send_progress(
|
|
|
|
|
|
f"生成第{batch_idx+1}/{total_batches}批角色 ({current_batch_size}个){retry_suffix}...",
|
|
|
|
|
|
batch_progress
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
# 构建批次要求 - 包含已生成角色信息保持连贯
|
|
|
|
|
|
existing_chars_context = ""
|
|
|
|
|
|
if all_characters:
|
|
|
|
|
|
existing_chars_context = "\n\n【已生成的角色】:\n"
|
|
|
|
|
|
for char in all_characters:
|
|
|
|
|
|
existing_chars_context += f"- {char.get('name')}: {char.get('role_type', '未知')}, {char.get('personality', '暂无')[:50]}...\n"
|
|
|
|
|
|
existing_chars_context += "\n请确保新角色与已有角色形成合理的关系网络和互动。\n"
|
|
|
|
|
|
|
|
|
|
|
|
# 构建精确的批次要求,明确告诉AI要生成的数量
|
|
|
|
|
|
if batch_idx == 0:
|
|
|
|
|
|
if current_batch_size == 1:
|
|
|
|
|
|
batch_requirements = f"{requirements}\n请生成1个主角(protagonist)"
|
|
|
|
|
|
else:
|
|
|
|
|
|
batch_requirements = f"{requirements}\n请精确生成{current_batch_size}个角色:1个主角(protagonist)和{current_batch_size-1}个核心配角(supporting)"
|
|
|
|
|
|
else:
|
|
|
|
|
|
batch_requirements = f"{requirements}\n请精确生成{current_batch_size}个角色{existing_chars_context}"
|
|
|
|
|
|
if batch_idx == total_batches - 1:
|
|
|
|
|
|
batch_requirements += "\n可以包含组织或反派(antagonist)"
|
|
|
|
|
|
else:
|
|
|
|
|
|
batch_requirements += "\n主要是配角(supporting)和反派(antagonist)"
|
|
|
|
|
|
|
2025-11-07 22:14:20 +08:00
|
|
|
|
# 构建基础提示词
|
|
|
|
|
|
base_prompt = prompt_service.get_characters_batch_prompt(
|
2025-10-30 11:14:43 +08:00
|
|
|
|
count=current_batch_size, # 传递精确数量
|
|
|
|
|
|
time_period=world_context.get("time_period", ""),
|
|
|
|
|
|
location=world_context.get("location", ""),
|
|
|
|
|
|
atmosphere=world_context.get("atmosphere", ""),
|
|
|
|
|
|
rules=world_context.get("rules", ""),
|
|
|
|
|
|
theme=theme or project.theme or "",
|
|
|
|
|
|
genre=genre or project.genre or "",
|
|
|
|
|
|
requirements=batch_requirements
|
|
|
|
|
|
)
|
|
|
|
|
|
|
2025-11-07 22:14:20 +08:00
|
|
|
|
# 如果有MCP参考资料,增强提示词
|
|
|
|
|
|
if character_reference_materials:
|
|
|
|
|
|
prompt = f"""{base_prompt}
|
|
|
|
|
|
|
|
|
|
|
|
【参考资料】
|
|
|
|
|
|
以下是通过MCP工具收集的真实背景资料,请参考这些信息设计更真实的角色:
|
|
|
|
|
|
|
|
|
|
|
|
{character_reference_materials}
|
|
|
|
|
|
|
|
|
|
|
|
请结合上述资料,设计符合历史/文化背景的角色。"""
|
|
|
|
|
|
else:
|
|
|
|
|
|
prompt = base_prompt
|
|
|
|
|
|
|
2025-10-30 11:14:43 +08:00
|
|
|
|
# 流式生成
|
|
|
|
|
|
accumulated_text = ""
|
2025-10-30 16:53:50 +08:00
|
|
|
|
async for chunk in user_ai_service.generate_text_stream(
|
2025-10-30 11:14:43 +08:00
|
|
|
|
prompt=prompt,
|
|
|
|
|
|
provider=provider,
|
|
|
|
|
|
model=model
|
|
|
|
|
|
):
|
|
|
|
|
|
accumulated_text += chunk
|
|
|
|
|
|
yield await SSEResponse.send_chunk(chunk)
|
|
|
|
|
|
|
|
|
|
|
|
# 解析批次结果
|
|
|
|
|
|
cleaned_text = accumulated_text.strip()
|
2025-10-30 16:53:50 +08:00
|
|
|
|
# 移除markdown代码块标记
|
2025-10-30 11:14:43 +08:00
|
|
|
|
if cleaned_text.startswith('```json'):
|
2025-10-30 16:53:50 +08:00
|
|
|
|
cleaned_text = cleaned_text[7:].lstrip('\n\r')
|
|
|
|
|
|
elif cleaned_text.startswith('```'):
|
|
|
|
|
|
cleaned_text = cleaned_text[3:].lstrip('\n\r')
|
2025-10-30 11:14:43 +08:00
|
|
|
|
if cleaned_text.endswith('```'):
|
2025-10-30 16:53:50 +08:00
|
|
|
|
cleaned_text = cleaned_text[:-3].rstrip('\n\r')
|
2025-10-30 11:14:43 +08:00
|
|
|
|
cleaned_text = cleaned_text.strip()
|
|
|
|
|
|
|
|
|
|
|
|
characters_data = json.loads(cleaned_text)
|
|
|
|
|
|
if not isinstance(characters_data, list):
|
|
|
|
|
|
characters_data = [characters_data]
|
|
|
|
|
|
|
2025-11-03 15:28:51 +08:00
|
|
|
|
# 严格验证生成数量是否精确匹配
|
2025-10-30 11:14:43 +08:00
|
|
|
|
if len(characters_data) != current_batch_size:
|
2025-11-03 15:28:51 +08:00
|
|
|
|
error_msg = f"批次{batch_idx+1}生成数量不正确: 期望{current_batch_size}个, 实际{len(characters_data)}个"
|
|
|
|
|
|
logger.error(error_msg)
|
2025-10-30 11:14:43 +08:00
|
|
|
|
|
2025-11-03 15:28:51 +08:00
|
|
|
|
# 如果还有重试机会,继续重试
|
|
|
|
|
|
if retry_count < MAX_RETRIES - 1:
|
|
|
|
|
|
retry_count += 1
|
2025-10-30 11:14:43 +08:00
|
|
|
|
yield await SSEResponse.send_progress(
|
2025-11-03 15:28:51 +08:00
|
|
|
|
f"⚠️ {error_msg},准备重试...",
|
2025-10-30 11:14:43 +08:00
|
|
|
|
batch_progress,
|
|
|
|
|
|
"warning"
|
|
|
|
|
|
)
|
2025-11-03 15:28:51 +08:00
|
|
|
|
continue
|
|
|
|
|
|
else:
|
|
|
|
|
|
# 最后一次重试仍失败,直接返回错误
|
|
|
|
|
|
yield await SSEResponse.send_error(error_msg)
|
|
|
|
|
|
return
|
2025-10-30 11:14:43 +08:00
|
|
|
|
|
|
|
|
|
|
all_characters.extend(characters_data)
|
|
|
|
|
|
batch_success = True
|
|
|
|
|
|
logger.info(f"批次{batch_idx+1}成功添加{len(characters_data)}个角色,当前总数{len(all_characters)}/{count}")
|
|
|
|
|
|
|
|
|
|
|
|
except json.JSONDecodeError as e:
|
|
|
|
|
|
logger.error(f"批次{batch_idx+1}解析失败(尝试{retry_count+1}/{MAX_RETRIES}): {e}")
|
2025-11-03 15:28:51 +08:00
|
|
|
|
batch_error_message = f"JSON解析失败: {str(e)}"
|
2025-10-30 11:14:43 +08:00
|
|
|
|
retry_count += 1
|
|
|
|
|
|
if retry_count < MAX_RETRIES:
|
|
|
|
|
|
yield await SSEResponse.send_progress(
|
|
|
|
|
|
f"解析失败,准备重试...",
|
|
|
|
|
|
batch_progress,
|
|
|
|
|
|
"warning"
|
|
|
|
|
|
)
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
logger.error(f"批次{batch_idx+1}生成异常(尝试{retry_count+1}/{MAX_RETRIES}): {e}")
|
2025-11-03 15:28:51 +08:00
|
|
|
|
batch_error_message = f"生成异常: {str(e)}"
|
2025-10-30 11:14:43 +08:00
|
|
|
|
retry_count += 1
|
|
|
|
|
|
if retry_count < MAX_RETRIES:
|
|
|
|
|
|
yield await SSEResponse.send_progress(
|
|
|
|
|
|
f"生成异常,准备重试...",
|
|
|
|
|
|
batch_progress,
|
|
|
|
|
|
"warning"
|
|
|
|
|
|
)
|
2025-11-03 15:28:51 +08:00
|
|
|
|
|
|
|
|
|
|
# 检查批次是否成功
|
|
|
|
|
|
if not batch_success:
|
|
|
|
|
|
error_msg = f"批次{batch_idx+1}在{MAX_RETRIES}次重试后仍然失败"
|
|
|
|
|
|
if batch_error_message:
|
|
|
|
|
|
error_msg += f": {batch_error_message}"
|
|
|
|
|
|
logger.error(error_msg)
|
|
|
|
|
|
yield await SSEResponse.send_error(error_msg)
|
|
|
|
|
|
return
|
2025-10-30 11:14:43 +08:00
|
|
|
|
|
|
|
|
|
|
# 保存到数据库 - 分阶段处理以保证一致性
|
|
|
|
|
|
yield await SSEResponse.send_progress("验证角色数据...", 82)
|
|
|
|
|
|
|
|
|
|
|
|
# 预处理:构建本批次所有实体的名称集合
|
|
|
|
|
|
valid_entity_names = set()
|
|
|
|
|
|
valid_organization_names = set()
|
|
|
|
|
|
|
|
|
|
|
|
for char_data in all_characters:
|
|
|
|
|
|
entity_name = char_data.get("name", "")
|
|
|
|
|
|
if entity_name:
|
|
|
|
|
|
valid_entity_names.add(entity_name)
|
|
|
|
|
|
if char_data.get("is_organization", False):
|
|
|
|
|
|
valid_organization_names.add(entity_name)
|
|
|
|
|
|
|
|
|
|
|
|
# 清理幻觉引用
|
|
|
|
|
|
cleaned_count = 0
|
|
|
|
|
|
for char_data in all_characters:
|
|
|
|
|
|
# 清理关系数组中的无效引用
|
|
|
|
|
|
if "relationships_array" in char_data and isinstance(char_data["relationships_array"], list):
|
|
|
|
|
|
original_rels = char_data["relationships_array"]
|
|
|
|
|
|
valid_rels = []
|
|
|
|
|
|
for rel in original_rels:
|
|
|
|
|
|
target_name = rel.get("target_character_name", "")
|
|
|
|
|
|
if target_name in valid_entity_names:
|
|
|
|
|
|
valid_rels.append(rel)
|
|
|
|
|
|
else:
|
|
|
|
|
|
cleaned_count += 1
|
|
|
|
|
|
logger.debug(f" 🧹 清理无效关系引用:{char_data.get('name')} -> {target_name}")
|
|
|
|
|
|
char_data["relationships_array"] = valid_rels
|
|
|
|
|
|
|
|
|
|
|
|
# 清理组织成员关系中的无效引用
|
|
|
|
|
|
if "organization_memberships" in char_data and isinstance(char_data["organization_memberships"], list):
|
|
|
|
|
|
original_orgs = char_data["organization_memberships"]
|
|
|
|
|
|
valid_orgs = []
|
|
|
|
|
|
for org_mem in original_orgs:
|
|
|
|
|
|
org_name = org_mem.get("organization_name", "")
|
|
|
|
|
|
if org_name in valid_organization_names:
|
|
|
|
|
|
valid_orgs.append(org_mem)
|
|
|
|
|
|
else:
|
|
|
|
|
|
cleaned_count += 1
|
|
|
|
|
|
logger.debug(f" 🧹 清理无效组织引用:{char_data.get('name')} -> {org_name}")
|
|
|
|
|
|
char_data["organization_memberships"] = valid_orgs
|
|
|
|
|
|
|
|
|
|
|
|
if cleaned_count > 0:
|
|
|
|
|
|
logger.info(f"✨ 清理了{cleaned_count}个AI幻觉引用")
|
|
|
|
|
|
yield await SSEResponse.send_progress(f"已清理{cleaned_count}个无效引用", 84)
|
|
|
|
|
|
|
|
|
|
|
|
yield await SSEResponse.send_progress("保存角色到数据库...", 85)
|
|
|
|
|
|
|
|
|
|
|
|
# 第一阶段:创建所有Character记录
|
|
|
|
|
|
created_characters = []
|
|
|
|
|
|
character_name_to_obj = {} # 名称到对象的映射,用于后续关系创建
|
|
|
|
|
|
|
|
|
|
|
|
for char_data in all_characters:
|
|
|
|
|
|
# 从relationships_array提取文本描述以保持向后兼容
|
|
|
|
|
|
relationships_text = ""
|
|
|
|
|
|
relationships_array = char_data.get("relationships_array", [])
|
|
|
|
|
|
if relationships_array and isinstance(relationships_array, list):
|
|
|
|
|
|
# 将关系数组转换为可读文本
|
|
|
|
|
|
rel_descriptions = []
|
|
|
|
|
|
for rel in relationships_array:
|
|
|
|
|
|
target = rel.get("target_character_name", "未知")
|
|
|
|
|
|
rel_type = rel.get("relationship_type", "关系")
|
|
|
|
|
|
desc = rel.get("description", "")
|
|
|
|
|
|
rel_descriptions.append(f"{target}({rel_type}): {desc}")
|
|
|
|
|
|
relationships_text = "; ".join(rel_descriptions)
|
|
|
|
|
|
# 兼容旧格式
|
|
|
|
|
|
elif isinstance(char_data.get("relationships"), dict):
|
|
|
|
|
|
relationships_text = json.dumps(char_data.get("relationships"), ensure_ascii=False)
|
|
|
|
|
|
elif isinstance(char_data.get("relationships"), str):
|
|
|
|
|
|
relationships_text = char_data.get("relationships")
|
|
|
|
|
|
|
2025-11-05 16:22:14 +08:00
|
|
|
|
# 判断是否为组织
|
|
|
|
|
|
is_organization = char_data.get("is_organization", False)
|
|
|
|
|
|
|
2025-10-30 11:14:43 +08:00
|
|
|
|
character = Character(
|
|
|
|
|
|
project_id=project_id,
|
|
|
|
|
|
name=char_data.get("name", "未命名角色"),
|
2025-11-05 16:22:14 +08:00
|
|
|
|
age=str(char_data.get("age", "")) if not is_organization else None,
|
|
|
|
|
|
gender=char_data.get("gender") if not is_organization else None,
|
|
|
|
|
|
is_organization=is_organization,
|
2025-10-30 11:14:43 +08:00
|
|
|
|
role_type=char_data.get("role_type", "supporting"),
|
|
|
|
|
|
personality=char_data.get("personality", ""),
|
|
|
|
|
|
background=char_data.get("background", ""),
|
|
|
|
|
|
appearance=char_data.get("appearance", ""),
|
|
|
|
|
|
relationships=relationships_text,
|
2025-11-05 16:22:14 +08:00
|
|
|
|
organization_type=char_data.get("organization_type") if is_organization else None,
|
|
|
|
|
|
organization_purpose=char_data.get("organization_purpose") if is_organization else None,
|
|
|
|
|
|
organization_members=json.dumps(char_data.get("organization_members", []), ensure_ascii=False) if is_organization else None,
|
|
|
|
|
|
traits=json.dumps(char_data.get("traits", []), ensure_ascii=False) if char_data.get("traits") else None
|
2025-10-30 11:14:43 +08:00
|
|
|
|
)
|
|
|
|
|
|
db.add(character)
|
|
|
|
|
|
created_characters.append((character, char_data))
|
|
|
|
|
|
|
|
|
|
|
|
await db.flush() # 获取所有角色的ID
|
|
|
|
|
|
|
|
|
|
|
|
# 刷新并建立名称映射
|
|
|
|
|
|
for character, _ in created_characters:
|
|
|
|
|
|
await db.refresh(character)
|
|
|
|
|
|
character_name_to_obj[character.name] = character
|
|
|
|
|
|
logger.info(f"向导创建角色:{character.name} (ID: {character.id}, 是否组织: {character.is_organization})")
|
|
|
|
|
|
|
|
|
|
|
|
# 为is_organization=True的角色创建Organization记录
|
|
|
|
|
|
yield await SSEResponse.send_progress("创建组织记录...", 87)
|
|
|
|
|
|
organization_name_to_obj = {} # 组织名称到Organization对象的映射
|
|
|
|
|
|
|
|
|
|
|
|
for character, char_data in created_characters:
|
|
|
|
|
|
if character.is_organization:
|
|
|
|
|
|
# 检查是否已存在Organization记录
|
|
|
|
|
|
org_check = await db.execute(
|
|
|
|
|
|
select(Organization).where(Organization.character_id == character.id)
|
|
|
|
|
|
)
|
|
|
|
|
|
existing_org = org_check.scalar_one_or_none()
|
|
|
|
|
|
|
|
|
|
|
|
if not existing_org:
|
|
|
|
|
|
# 创建Organization记录
|
|
|
|
|
|
org = Organization(
|
|
|
|
|
|
character_id=character.id,
|
|
|
|
|
|
project_id=project_id,
|
|
|
|
|
|
member_count=0, # 初始为0,后续添加成员时会更新
|
2025-11-05 16:22:14 +08:00
|
|
|
|
power_level=char_data.get("power_level", 50),
|
2025-10-30 11:14:43 +08:00
|
|
|
|
location=char_data.get("location"),
|
2025-11-05 16:22:14 +08:00
|
|
|
|
motto=char_data.get("motto"),
|
|
|
|
|
|
color=char_data.get("color")
|
2025-10-30 11:14:43 +08:00
|
|
|
|
)
|
|
|
|
|
|
db.add(org)
|
|
|
|
|
|
logger.info(f"向导创建组织记录:{character.name}")
|
|
|
|
|
|
else:
|
|
|
|
|
|
org = existing_org
|
|
|
|
|
|
|
|
|
|
|
|
# 建立组织名称映射(无论是新建还是已存在)
|
|
|
|
|
|
organization_name_to_obj[character.name] = org
|
|
|
|
|
|
|
|
|
|
|
|
await db.flush() # 确保Organization记录有ID
|
|
|
|
|
|
|
|
|
|
|
|
# 刷新角色以获取ID
|
|
|
|
|
|
for character, _ in created_characters:
|
|
|
|
|
|
await db.refresh(character)
|
|
|
|
|
|
|
|
|
|
|
|
# 第三阶段:创建角色间的关系
|
|
|
|
|
|
yield await SSEResponse.send_progress("创建角色关系...", 90)
|
|
|
|
|
|
relationships_created = 0
|
|
|
|
|
|
|
|
|
|
|
|
for character, char_data in created_characters:
|
|
|
|
|
|
# 跳过组织实体的角色关系处理(组织通过成员关系关联)
|
|
|
|
|
|
if character.is_organization:
|
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
|
|
# 处理relationships数组
|
|
|
|
|
|
relationships_data = char_data.get("relationships_array", [])
|
|
|
|
|
|
if not relationships_data and isinstance(char_data.get("relationships"), list):
|
|
|
|
|
|
relationships_data = char_data.get("relationships")
|
|
|
|
|
|
|
|
|
|
|
|
if relationships_data and isinstance(relationships_data, list):
|
|
|
|
|
|
for rel in relationships_data:
|
|
|
|
|
|
try:
|
|
|
|
|
|
target_name = rel.get("target_character_name")
|
|
|
|
|
|
if not target_name:
|
|
|
|
|
|
logger.debug(f" ⚠️ {character.name}的关系缺少target_character_name,跳过")
|
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
|
|
# 使用名称映射快速查找
|
|
|
|
|
|
target_char = character_name_to_obj.get(target_name)
|
|
|
|
|
|
|
|
|
|
|
|
if target_char:
|
|
|
|
|
|
# 避免创建重复关系
|
|
|
|
|
|
existing_rel = await db.execute(
|
|
|
|
|
|
select(CharacterRelationship).where(
|
|
|
|
|
|
CharacterRelationship.project_id == project_id,
|
|
|
|
|
|
CharacterRelationship.character_from_id == character.id,
|
|
|
|
|
|
CharacterRelationship.character_to_id == target_char.id
|
|
|
|
|
|
)
|
|
|
|
|
|
)
|
|
|
|
|
|
if existing_rel.scalar_one_or_none():
|
|
|
|
|
|
logger.debug(f" ℹ️ 关系已存在:{character.name} -> {target_name}")
|
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
|
|
relationship = CharacterRelationship(
|
|
|
|
|
|
project_id=project_id,
|
|
|
|
|
|
character_from_id=character.id,
|
|
|
|
|
|
character_to_id=target_char.id,
|
|
|
|
|
|
relationship_name=rel.get("relationship_type", "未知关系"),
|
|
|
|
|
|
intimacy_level=rel.get("intimacy_level", 50),
|
|
|
|
|
|
description=rel.get("description", ""),
|
|
|
|
|
|
started_at=rel.get("started_at"),
|
|
|
|
|
|
source="ai"
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
# 匹配预定义关系类型
|
|
|
|
|
|
rel_type_result = await db.execute(
|
|
|
|
|
|
select(RelationshipType).where(
|
|
|
|
|
|
RelationshipType.name == rel.get("relationship_type")
|
|
|
|
|
|
)
|
|
|
|
|
|
)
|
|
|
|
|
|
rel_type = rel_type_result.scalar_one_or_none()
|
|
|
|
|
|
if rel_type:
|
|
|
|
|
|
relationship.relationship_type_id = rel_type.id
|
|
|
|
|
|
|
|
|
|
|
|
db.add(relationship)
|
|
|
|
|
|
relationships_created += 1
|
|
|
|
|
|
logger.info(f" ✅ 向导创建关系:{character.name} -> {target_name} ({rel.get('relationship_type')})")
|
|
|
|
|
|
else:
|
|
|
|
|
|
logger.warning(f" ⚠️ 目标角色不存在:{character.name} -> {target_name}(可能是AI幻觉)")
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
logger.warning(f" ❌ 向导创建关系失败:{character.name} - {str(e)}")
|
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
|
|
# 第四阶段:创建组织成员关系
|
|
|
|
|
|
yield await SSEResponse.send_progress("创建组织成员关系...", 93)
|
|
|
|
|
|
members_created = 0
|
|
|
|
|
|
|
|
|
|
|
|
for character, char_data in created_characters:
|
|
|
|
|
|
# 跳过组织实体本身
|
|
|
|
|
|
if character.is_organization:
|
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
|
|
# 处理组织成员关系
|
|
|
|
|
|
org_memberships = char_data.get("organization_memberships", [])
|
|
|
|
|
|
if org_memberships and isinstance(org_memberships, list):
|
|
|
|
|
|
for membership in org_memberships:
|
|
|
|
|
|
try:
|
|
|
|
|
|
org_name = membership.get("organization_name")
|
|
|
|
|
|
if not org_name:
|
|
|
|
|
|
logger.debug(f" ⚠️ {character.name}的组织成员关系缺少organization_name,跳过")
|
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
|
|
# 使用映射快速查找组织
|
|
|
|
|
|
org = organization_name_to_obj.get(org_name)
|
|
|
|
|
|
|
|
|
|
|
|
if org:
|
|
|
|
|
|
# 检查是否已存在成员关系
|
|
|
|
|
|
existing_member = await db.execute(
|
|
|
|
|
|
select(OrganizationMember).where(
|
|
|
|
|
|
OrganizationMember.organization_id == org.id,
|
|
|
|
|
|
OrganizationMember.character_id == character.id
|
|
|
|
|
|
)
|
|
|
|
|
|
)
|
|
|
|
|
|
if existing_member.scalar_one_or_none():
|
|
|
|
|
|
logger.debug(f" ℹ️ 成员关系已存在:{character.name} -> {org_name}")
|
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
|
|
# 创建成员关系
|
|
|
|
|
|
member = OrganizationMember(
|
|
|
|
|
|
organization_id=org.id,
|
|
|
|
|
|
character_id=character.id,
|
|
|
|
|
|
position=membership.get("position", "成员"),
|
|
|
|
|
|
rank=membership.get("rank", 0),
|
|
|
|
|
|
loyalty=membership.get("loyalty", 50),
|
|
|
|
|
|
joined_at=membership.get("joined_at"),
|
|
|
|
|
|
status=membership.get("status", "active"),
|
|
|
|
|
|
source="ai"
|
|
|
|
|
|
)
|
|
|
|
|
|
db.add(member)
|
|
|
|
|
|
|
|
|
|
|
|
# 更新组织成员计数
|
|
|
|
|
|
org.member_count += 1
|
|
|
|
|
|
|
|
|
|
|
|
members_created += 1
|
|
|
|
|
|
logger.info(f" ✅ 向导添加成员:{character.name} -> {org_name} ({membership.get('position')})")
|
|
|
|
|
|
else:
|
|
|
|
|
|
# 这种情况理论上已经被预处理清理了,但保留日志以防万一
|
|
|
|
|
|
logger.debug(f" ℹ️ 组织引用已被清理:{character.name} -> {org_name}")
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
logger.warning(f" ❌ 向导添加组织成员失败:{character.name} - {str(e)}")
|
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
|
|
logger.info(f"📊 向导数据统计:")
|
|
|
|
|
|
logger.info(f" - 创建角色/组织:{len(created_characters)} 个")
|
|
|
|
|
|
logger.info(f" - 创建组织详情:{len(organization_name_to_obj)} 个")
|
|
|
|
|
|
logger.info(f" - 创建角色关系:{relationships_created} 条")
|
|
|
|
|
|
logger.info(f" - 创建组织成员:{members_created} 条")
|
|
|
|
|
|
|
2025-11-03 15:28:51 +08:00
|
|
|
|
# 更新项目的角色数量
|
|
|
|
|
|
project.character_count = len(created_characters)
|
|
|
|
|
|
logger.info(f"✅ 更新项目角色数量: {project.character_count}")
|
|
|
|
|
|
|
2025-10-30 11:14:43 +08:00
|
|
|
|
await db.commit()
|
|
|
|
|
|
db_committed = True
|
|
|
|
|
|
|
|
|
|
|
|
# 重新提取character对象
|
|
|
|
|
|
created_characters = [char for char, _ in created_characters]
|
|
|
|
|
|
|
|
|
|
|
|
# 发送结果
|
|
|
|
|
|
yield await SSEResponse.send_result({
|
|
|
|
|
|
"message": f"成功生成{len(created_characters)}个角色/组织(分{total_batches}批完成)",
|
|
|
|
|
|
"count": len(created_characters),
|
|
|
|
|
|
"batches": total_batches,
|
|
|
|
|
|
"characters": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"id": char.id,
|
|
|
|
|
|
"project_id": char.project_id,
|
|
|
|
|
|
"name": char.name,
|
|
|
|
|
|
"age": char.age,
|
|
|
|
|
|
"gender": char.gender,
|
|
|
|
|
|
"is_organization": char.is_organization,
|
|
|
|
|
|
"role_type": char.role_type,
|
|
|
|
|
|
"personality": char.personality,
|
|
|
|
|
|
"background": char.background,
|
|
|
|
|
|
"appearance": char.appearance,
|
|
|
|
|
|
"relationships": char.relationships,
|
|
|
|
|
|
"organization_type": char.organization_type,
|
|
|
|
|
|
"organization_purpose": char.organization_purpose,
|
|
|
|
|
|
"organization_members": char.organization_members,
|
|
|
|
|
|
"traits": char.traits,
|
|
|
|
|
|
"created_at": char.created_at.isoformat() if char.created_at else None,
|
|
|
|
|
|
"updated_at": char.updated_at.isoformat() if char.updated_at else None
|
|
|
|
|
|
} for char in created_characters
|
|
|
|
|
|
]
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
yield await SSEResponse.send_progress("完成!", 100, "success")
|
|
|
|
|
|
yield await SSEResponse.send_done()
|
|
|
|
|
|
|
|
|
|
|
|
except GeneratorExit:
|
|
|
|
|
|
logger.warning("角色生成器被提前关闭")
|
|
|
|
|
|
if not db_committed and db.in_transaction():
|
|
|
|
|
|
await db.rollback()
|
|
|
|
|
|
logger.info("角色生成事务已回滚(GeneratorExit)")
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
logger.error(f"角色生成失败: {str(e)}")
|
|
|
|
|
|
if not db_committed and db.in_transaction():
|
|
|
|
|
|
await db.rollback()
|
|
|
|
|
|
logger.info("角色生成事务已回滚(异常)")
|
|
|
|
|
|
yield await SSEResponse.send_error(f"生成失败: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/characters", summary="流式批量生成角色")
|
|
|
|
|
|
async def generate_characters_stream(
|
2025-11-07 22:14:20 +08:00
|
|
|
|
request: Request,
|
2025-10-30 11:14:43 +08:00
|
|
|
|
data: Dict[str, Any],
|
2025-10-30 16:53:50 +08:00
|
|
|
|
db: AsyncSession = Depends(get_db),
|
|
|
|
|
|
user_ai_service: AIService = Depends(get_user_ai_service)
|
2025-10-30 11:14:43 +08:00
|
|
|
|
):
|
|
|
|
|
|
"""
|
|
|
|
|
|
使用SSE流式批量生成角色,避免超时
|
2025-11-07 22:14:20 +08:00
|
|
|
|
支持MCP工具增强
|
2025-10-30 11:14:43 +08:00
|
|
|
|
"""
|
2025-11-07 22:14:20 +08:00
|
|
|
|
# 从中间件注入user_id到data中
|
|
|
|
|
|
if hasattr(request.state, 'user_id'):
|
|
|
|
|
|
data['user_id'] = request.state.user_id
|
|
|
|
|
|
|
2025-10-30 16:53:50 +08:00
|
|
|
|
return create_sse_response(characters_generator(data, db, user_ai_service))
|
2025-10-30 11:14:43 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def outline_generator(
|
|
|
|
|
|
data: Dict[str, Any],
|
2025-10-30 16:53:50 +08:00
|
|
|
|
db: AsyncSession,
|
|
|
|
|
|
user_ai_service: AIService
|
2025-10-30 11:14:43 +08:00
|
|
|
|
) -> AsyncGenerator[str, None]:
|
2025-11-21 15:49:39 +08:00
|
|
|
|
"""大纲生成流式生成器 - 向导仅生成大纲节点,不展开章节(避免等待过久)"""
|
2025-10-30 11:14:43 +08:00
|
|
|
|
db_committed = False
|
|
|
|
|
|
try:
|
|
|
|
|
|
yield await SSEResponse.send_progress("开始生成大纲...", 5)
|
|
|
|
|
|
|
|
|
|
|
|
project_id = data.get("project_id")
|
2025-11-21 15:49:39 +08:00
|
|
|
|
# 向导固定生成3个大纲节点(不展开)
|
|
|
|
|
|
outline_count = data.get("chapter_count", 3)
|
2025-10-30 11:14:43 +08:00
|
|
|
|
narrative_perspective = data.get("narrative_perspective")
|
|
|
|
|
|
target_words = data.get("target_words", 100000)
|
|
|
|
|
|
requirements = data.get("requirements", "")
|
|
|
|
|
|
provider = data.get("provider")
|
|
|
|
|
|
model = data.get("model")
|
|
|
|
|
|
|
|
|
|
|
|
# 获取项目信息
|
|
|
|
|
|
yield await SSEResponse.send_progress("加载项目信息...", 10)
|
|
|
|
|
|
result = await db.execute(
|
|
|
|
|
|
select(Project).where(Project.id == project_id)
|
|
|
|
|
|
)
|
|
|
|
|
|
project = result.scalar_one_or_none()
|
|
|
|
|
|
if not project:
|
|
|
|
|
|
yield await SSEResponse.send_error("项目不存在", 404)
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
# 获取角色信息
|
|
|
|
|
|
yield await SSEResponse.send_progress("加载角色信息...", 15)
|
|
|
|
|
|
result = await db.execute(
|
|
|
|
|
|
select(Character).where(Character.project_id == project_id)
|
|
|
|
|
|
)
|
|
|
|
|
|
characters = result.scalars().all()
|
|
|
|
|
|
|
|
|
|
|
|
characters_info = "\n".join([
|
|
|
|
|
|
f"- {char.name} ({'组织' if char.is_organization else '角色'}, {char.role_type}): {char.personality[:100] if char.personality else '暂无描述'}"
|
|
|
|
|
|
for char in characters
|
|
|
|
|
|
])
|
|
|
|
|
|
|
2025-11-19 13:30:55 +08:00
|
|
|
|
# 第一阶段:生成3个粗粒度大纲节点
|
|
|
|
|
|
yield await SSEResponse.send_progress(f"生成{outline_count}个大纲节点...", 20)
|
|
|
|
|
|
|
|
|
|
|
|
outline_requirements = f"{requirements}\n\n【重要说明】这是小说的开局部分,请生成{outline_count}个大纲节点,重点关注:\n"
|
|
|
|
|
|
outline_requirements += "1. 引入主要角色和世界观设定\n"
|
|
|
|
|
|
outline_requirements += "2. 建立主线冲突和故事钩子\n"
|
|
|
|
|
|
outline_requirements += "3. 展开初期情节,为后续发展埋下伏笔\n"
|
|
|
|
|
|
outline_requirements += "4. 不要试图完结故事,这只是开始部分\n"
|
|
|
|
|
|
outline_requirements += "5. 不要在JSON字符串值中使用中文引号(""''),请使用【】或《》标记\n"
|
|
|
|
|
|
|
|
|
|
|
|
outline_prompt = prompt_service.get_complete_outline_prompt(
|
|
|
|
|
|
title=project.title,
|
|
|
|
|
|
theme=project.theme or "未设定",
|
|
|
|
|
|
genre=project.genre or "通用",
|
|
|
|
|
|
chapter_count=outline_count,
|
|
|
|
|
|
narrative_perspective=narrative_perspective,
|
|
|
|
|
|
target_words=target_words // 10, # 开局约占总字数的1/10
|
|
|
|
|
|
time_period=project.world_time_period or "未设定",
|
|
|
|
|
|
location=project.world_location or "未设定",
|
|
|
|
|
|
atmosphere=project.world_atmosphere or "未设定",
|
|
|
|
|
|
rules=project.world_rules or "未设定",
|
|
|
|
|
|
characters_info=characters_info or "暂无角色信息",
|
|
|
|
|
|
requirements=outline_requirements
|
|
|
|
|
|
)
|
2025-10-30 11:14:43 +08:00
|
|
|
|
|
2025-11-19 13:30:55 +08:00
|
|
|
|
# 流式生成大纲
|
|
|
|
|
|
accumulated_text = ""
|
|
|
|
|
|
async for chunk in user_ai_service.generate_text_stream(
|
|
|
|
|
|
prompt=outline_prompt,
|
|
|
|
|
|
provider=provider,
|
|
|
|
|
|
model=model
|
|
|
|
|
|
):
|
|
|
|
|
|
accumulated_text += chunk
|
|
|
|
|
|
yield await SSEResponse.send_chunk(chunk)
|
2025-10-30 11:14:43 +08:00
|
|
|
|
|
2025-11-19 13:30:55 +08:00
|
|
|
|
# 解析大纲结果
|
|
|
|
|
|
yield await SSEResponse.send_progress("解析大纲...", 40)
|
|
|
|
|
|
cleaned_text = accumulated_text.strip()
|
|
|
|
|
|
if cleaned_text.startswith('```json'):
|
|
|
|
|
|
cleaned_text = cleaned_text[7:].lstrip('\n\r')
|
|
|
|
|
|
elif cleaned_text.startswith('```'):
|
|
|
|
|
|
cleaned_text = cleaned_text[3:].lstrip('\n\r')
|
|
|
|
|
|
if cleaned_text.endswith('```'):
|
|
|
|
|
|
cleaned_text = cleaned_text[:-3].rstrip('\n\r')
|
|
|
|
|
|
cleaned_text = cleaned_text.strip()
|
2025-10-30 11:14:43 +08:00
|
|
|
|
|
2025-11-19 13:30:55 +08:00
|
|
|
|
try:
|
|
|
|
|
|
outline_data = json.loads(cleaned_text)
|
|
|
|
|
|
if not isinstance(outline_data, list):
|
|
|
|
|
|
outline_data = [outline_data]
|
|
|
|
|
|
except json.JSONDecodeError as e:
|
|
|
|
|
|
logger.error(f"大纲JSON解析失败: {e}")
|
|
|
|
|
|
yield await SSEResponse.send_error("大纲生成失败,请重试")
|
2025-10-30 11:14:43 +08:00
|
|
|
|
return
|
|
|
|
|
|
|
2025-11-19 13:30:55 +08:00
|
|
|
|
# 保存大纲到数据库
|
|
|
|
|
|
yield await SSEResponse.send_progress("保存大纲到数据库...", 45)
|
2025-10-30 11:14:43 +08:00
|
|
|
|
created_outlines = []
|
2025-11-19 13:30:55 +08:00
|
|
|
|
for index, outline_item in enumerate(outline_data[:outline_count], 1):
|
2025-10-30 11:14:43 +08:00
|
|
|
|
outline = Outline(
|
|
|
|
|
|
project_id=project_id,
|
2025-11-19 13:30:55 +08:00
|
|
|
|
title=outline_item.get("title", f"第{index}节"),
|
|
|
|
|
|
content=outline_item.get("summary", outline_item.get("content", "")),
|
|
|
|
|
|
structure=json.dumps(outline_item, ensure_ascii=False),
|
|
|
|
|
|
order_index=index
|
2025-10-30 11:14:43 +08:00
|
|
|
|
)
|
|
|
|
|
|
db.add(outline)
|
|
|
|
|
|
created_outlines.append(outline)
|
2025-11-19 13:30:55 +08:00
|
|
|
|
|
|
|
|
|
|
await db.flush() # 获取大纲ID
|
|
|
|
|
|
for outline in created_outlines:
|
|
|
|
|
|
await db.refresh(outline)
|
|
|
|
|
|
|
|
|
|
|
|
logger.info(f"✅ 成功创建{len(created_outlines)}个大纲节点")
|
|
|
|
|
|
|
2025-11-21 15:49:39 +08:00
|
|
|
|
# 向导流程中不展开大纲,避免等待时间过长
|
|
|
|
|
|
# 用户可以在大纲页面手动展开需要的大纲节点
|
|
|
|
|
|
yield await SSEResponse.send_progress("跳过大纲展开,加快创建速度...", 85)
|
2025-10-30 11:14:43 +08:00
|
|
|
|
|
2025-11-19 13:30:55 +08:00
|
|
|
|
# 更新项目信息
|
2025-11-21 15:49:39 +08:00
|
|
|
|
project.chapter_count = 0 # 向导阶段不创建章节
|
2025-10-30 11:14:43 +08:00
|
|
|
|
project.narrative_perspective = narrative_perspective
|
|
|
|
|
|
project.target_words = target_words
|
|
|
|
|
|
project.status = "writing"
|
|
|
|
|
|
project.wizard_status = "completed"
|
|
|
|
|
|
project.wizard_step = 4
|
|
|
|
|
|
|
|
|
|
|
|
await db.commit()
|
|
|
|
|
|
db_committed = True
|
|
|
|
|
|
|
2025-11-19 13:30:55 +08:00
|
|
|
|
logger.info(f"📊 向导大纲生成完成:")
|
|
|
|
|
|
logger.info(f" - 创建大纲节点:{len(created_outlines)} 个")
|
2025-11-21 15:49:39 +08:00
|
|
|
|
logger.info(f" - 提示:可在大纲页面手动展开为章节")
|
2025-11-19 13:30:55 +08:00
|
|
|
|
|
2025-10-30 11:14:43 +08:00
|
|
|
|
# 发送结果
|
|
|
|
|
|
yield await SSEResponse.send_result({
|
2025-11-21 15:49:39 +08:00
|
|
|
|
"message": f"成功生成{len(created_outlines)}个大纲节点(未展开章节,可在大纲页面手动展开)",
|
2025-11-19 13:30:55 +08:00
|
|
|
|
"outline_count": len(created_outlines),
|
2025-11-21 15:49:39 +08:00
|
|
|
|
"chapter_count": 0,
|
2025-10-30 11:14:43 +08:00
|
|
|
|
"outlines": [
|
|
|
|
|
|
{
|
2025-11-19 13:30:55 +08:00
|
|
|
|
"id": outline.id,
|
2025-10-30 11:14:43 +08:00
|
|
|
|
"order_index": outline.order_index,
|
|
|
|
|
|
"title": outline.title,
|
2025-11-21 15:49:39 +08:00
|
|
|
|
"content": outline.content[:100] + "..." if len(outline.content) > 100 else outline.content,
|
|
|
|
|
|
"note": "可在大纲页面展开为章节"
|
2025-10-30 11:14:43 +08:00
|
|
|
|
} for outline in created_outlines
|
|
|
|
|
|
]
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
yield await SSEResponse.send_progress("完成!", 100, "success")
|
|
|
|
|
|
yield await SSEResponse.send_done()
|
|
|
|
|
|
|
|
|
|
|
|
except GeneratorExit:
|
|
|
|
|
|
logger.warning("大纲生成器被提前关闭")
|
|
|
|
|
|
if not db_committed and db.in_transaction():
|
|
|
|
|
|
await db.rollback()
|
|
|
|
|
|
logger.info("大纲生成事务已回滚(GeneratorExit)")
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
logger.error(f"大纲生成失败: {str(e)}")
|
|
|
|
|
|
if not db_committed and db.in_transaction():
|
|
|
|
|
|
await db.rollback()
|
|
|
|
|
|
logger.info("大纲生成事务已回滚(异常)")
|
|
|
|
|
|
yield await SSEResponse.send_error(f"生成失败: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/outline", summary="流式生成完整大纲")
|
|
|
|
|
|
async def generate_outline_stream(
|
|
|
|
|
|
data: Dict[str, Any],
|
2025-10-30 16:53:50 +08:00
|
|
|
|
db: AsyncSession = Depends(get_db),
|
|
|
|
|
|
user_ai_service: AIService = Depends(get_user_ai_service)
|
2025-10-30 11:14:43 +08:00
|
|
|
|
):
|
|
|
|
|
|
"""
|
|
|
|
|
|
使用SSE流式生成完整大纲,避免超时
|
|
|
|
|
|
"""
|
2025-10-30 16:53:50 +08:00
|
|
|
|
return create_sse_response(outline_generator(data, db, user_ai_service))
|
2025-10-30 11:14:43 +08:00
|
|
|
|
|