Files
MuMuAINovel/backend/app/api/wizard_stream.py
T

1336 lines
60 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""项目创建向导流式API - 使用SSE避免超时"""
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from typing import Dict, Any, AsyncGenerator
import json
import re
from app.database import get_db
from app.models.project import Project
from app.models.character import Character
from app.models.outline import Outline
from app.models.chapter import Chapter
from app.models.relationship import CharacterRelationship, Organization, OrganizationMember, RelationshipType
from app.models.writing_style import WritingStyle
from app.models.project_default_style import ProjectDefaultStyle
from app.services.ai_service import AIService
from app.services.mcp_tool_service import MCPToolService
from app.services.prompt_service import prompt_service, PromptService
from app.services.plot_expansion_service import PlotExpansionService
from app.logger import get_logger
from app.utils.sse_response import SSEResponse, create_sse_response
from app.api.settings import get_user_ai_service
router = APIRouter(prefix="/wizard-stream", tags=["项目创建向导(流式)"])
logger = get_logger(__name__)
async def world_building_generator(
data: Dict[str, Any],
db: AsyncSession,
user_ai_service: AIService
) -> AsyncGenerator[str, None]:
"""世界构建流式生成器 - 支持MCP工具增强"""
# 标记数据库会话是否已提交
db_committed = False
try:
# 发送开始消息
yield await SSEResponse.send_progress("开始生成世界观...", 10)
# 提取参数
title = data.get("title")
description = data.get("description")
theme = data.get("theme")
genre = data.get("genre")
narrative_perspective = data.get("narrative_perspective")
target_words = data.get("target_words")
chapter_count = data.get("chapter_count")
character_count = data.get("character_count")
outline_mode = data.get("outline_mode", "one-to-many") # 大纲模式,默认一对多
provider = data.get("provider")
model = data.get("model")
enable_mcp = data.get("enable_mcp", True) # 默认启用MCP
user_id = data.get("user_id") # 从中间件注入
if not title or not description or not theme or not genre:
yield await SSEResponse.send_error("title、description、theme 和 genre 是必需的参数", 400)
return
# 获取基础提示词(支持自定义)
yield await SSEResponse.send_progress("准备AI提示词...", 15)
template = await PromptService.get_template("WORLD_BUILDING", user_id, db)
base_prompt = PromptService.format_prompt(
template,
title=title,
theme=theme,
genre=genre or "通用类型"
)
# MCP工具增强:收集参考资料
reference_materials = ""
if enable_mcp and user_id:
try:
# 先静默检查是否有可用工具
from app.services.mcp_tool_service import mcp_tool_service
available_tools = await mcp_tool_service.get_user_enabled_tools(
user_id=user_id,
db_session=db
)
# 只有在真正有可用工具时才显示消息和调用
if available_tools:
yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集参考资料...", 18)
# 构建资料收集提示词
planning_prompt = f"""你正在为小说《{title}》设计世界观。
【小说信息】
- 题材:{genre}
- 主题:{theme}
- 简介:{description}
【任务】
请使用可用工具搜索相关背景资料,帮助构建更真实、更有深度的世界观设定。
你可以查询:
1. 历史背景(如果是历史题材)
2. 地理环境和文化特征
3. 相关领域的专业知识
4. 类似作品的设定参考
请查询最关键的1个问题(不要超过1个)。"""
# 调用MCP增强的AI(非流式,最多1轮工具调用,避免超时)
planning_result = await user_ai_service.generate_text_with_mcp(
prompt=planning_prompt,
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=1,
tool_choice="auto",
provider=None,
model=None
)
# 提取参考资料
if planning_result.get("tool_calls_made", 0) > 0:
yield await SSEResponse.send_progress(
f"✅ MCP工具调用成功({planning_result['tool_calls_made']}次)",
25
)
reference_materials = planning_result.get("content", "")
else:
# 有工具但未使用
logger.debug("MCP工具可用但AI未选择使用")
else:
# 没有可用工具,静默跳过
logger.debug(f"用户 {user_id} 未启用MCP工具,跳过MCP增强")
except Exception as e:
logger.warning(f"MCP工具调用失败(降级处理): {e}")
yield await SSEResponse.send_progress("⚠️ MCP工具暂时不可用,使用基础模式", 25)
# 构建增强提示词
if reference_materials:
enhanced_prompt = f"""{base_prompt}
【参考资料】
以下是通过MCP工具收集的真实背景资料,请参考这些信息构建更真实的世界观:
{reference_materials}
请结合上述资料,生成符合历史/现实的世界观设定。"""
final_prompt = enhanced_prompt
yield await SSEResponse.send_progress("💡 已整合参考资料,开始生成世界观...", 30)
else:
final_prompt = base_prompt
yield await SSEResponse.send_progress("正在调用AI生成...", 30)
# 流式生成世界观
accumulated_text = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(
prompt=final_prompt,
provider=provider,
model=model
):
chunk_count += 1
accumulated_text += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度
if chunk_count % 5 == 0:
progress = min(30 + (chunk_count // 5), 70)
yield await SSEResponse.send_progress(f"生成中... ({len(accumulated_text)}字符)", progress)
# 每20个块发送心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
# 解析结果
yield await SSEResponse.send_progress("解析AI返回结果...", 80)
world_data = {}
try:
cleaned_text = accumulated_text.strip()
# 移除markdown代码块标记
if cleaned_text.startswith('```json'):
cleaned_text = cleaned_text[7:].lstrip('\n\r')
elif cleaned_text.startswith('```'):
cleaned_text = cleaned_text[3:].lstrip('\n\r')
if cleaned_text.endswith('```'):
cleaned_text = cleaned_text[:-3].rstrip('\n\r')
cleaned_text = cleaned_text.strip()
world_data = json.loads(cleaned_text)
except json.JSONDecodeError as e:
logger.error(f"世界构建JSON解析失败: {e}")
world_data = {
"time_period": "AI返回格式错误,请重试",
"location": "AI返回格式错误,请重试",
"atmosphere": "AI返回格式错误,请重试",
"rules": "AI返回格式错误,请重试"
}
# 保存到数据库
yield await SSEResponse.send_progress("保存到数据库...", 90)
# 确保user_id存在
if not user_id:
yield await SSEResponse.send_error("用户ID缺失,无法创建项目", 401)
return
project = Project(
user_id=user_id, # 添加user_id字段
title=title,
description=description,
theme=theme,
genre=genre,
world_time_period=world_data.get("time_period"),
world_location=world_data.get("location"),
world_atmosphere=world_data.get("atmosphere"),
world_rules=world_data.get("rules"),
narrative_perspective=narrative_perspective,
target_words=target_words,
chapter_count=chapter_count,
character_count=character_count,
outline_mode=outline_mode, # 设置大纲模式
wizard_status="incomplete",
wizard_step=1,
status="planning"
)
db.add(project)
await db.commit()
await db.refresh(project)
# 自动设置默认写作风格为第一个全局预设风格
try:
result = await db.execute(
select(WritingStyle).where(
WritingStyle.user_id.is_(None),
WritingStyle.order_index == 1
).limit(1)
)
first_style = result.scalar_one_or_none()
if first_style:
default_style = ProjectDefaultStyle(
project_id=project.id,
style_id=first_style.id
)
db.add(default_style)
await db.commit()
logger.info(f"为项目 {project.id} 自动设置默认风格: {first_style.name}")
else:
logger.warning(f"未找到order_index=1的全局预设风格,项目 {project.id} 未设置默认风格")
except Exception as e:
logger.warning(f"设置默认写作风格失败: {e},不影响项目创建")
# 更新向导步骤状态为1(世界观已完成)
project.wizard_step = 1
await db.commit()
db_committed = True
# 发送最终结果
yield await SSEResponse.send_result({
"project_id": project.id,
"time_period": world_data.get("time_period"),
"location": world_data.get("location"),
"atmosphere": world_data.get("atmosphere"),
"rules": world_data.get("rules")
})
yield await SSEResponse.send_progress("完成!", 100, "success")
yield await SSEResponse.send_done()
except GeneratorExit:
# SSE连接断开,回滚未提交的事务
logger.warning("世界构建生成器被提前关闭")
if not db_committed and db.in_transaction():
await db.rollback()
logger.info("世界构建事务已回滚(GeneratorExit")
except Exception as e:
logger.error(f"世界构建流式生成失败: {str(e)}")
# 异常时回滚事务
if not db_committed and db.in_transaction():
await db.rollback()
logger.info("世界构建事务已回滚(异常)")
yield await SSEResponse.send_error(f"生成失败: {str(e)}")
@router.post("/world-building", summary="流式生成世界构建")
async def generate_world_building_stream(
request: Request,
data: Dict[str, Any],
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
"""
使用SSE流式生成世界构建,避免超时
前端使用EventSource接收实时进度和结果
"""
# 从中间件注入user_id到data中
if hasattr(request.state, 'user_id'):
data['user_id'] = request.state.user_id
return create_sse_response(world_building_generator(data, db, user_ai_service))
async def characters_generator(
data: Dict[str, Any],
db: AsyncSession,
user_ai_service: AIService
) -> AsyncGenerator[str, None]:
"""角色批量生成流式生成器 - 优化版:分批+重试+MCP工具增强"""
db_committed = False
try:
yield await SSEResponse.send_progress("开始生成角色...", 5)
project_id = data.get("project_id")
count = data.get("count", 5)
world_context = data.get("world_context")
theme = data.get("theme", "")
genre = data.get("genre", "")
requirements = data.get("requirements", "")
provider = data.get("provider")
model = data.get("model")
enable_mcp = data.get("enable_mcp", True) # 默认启用MCP
user_id = data.get("user_id") # 从中间件注入
# 验证项目
yield await SSEResponse.send_progress("验证项目...", 10)
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
yield await SSEResponse.send_error("项目不存在", 404)
return
project.wizard_step = 2
world_context = world_context or {
"time_period": project.world_time_period or "未设定",
"location": project.world_location or "未设定",
"atmosphere": project.world_atmosphere or "未设定",
"rules": project.world_rules or "未设定"
}
# MCP工具增强:收集角色参考资料
character_reference_materials = ""
if enable_mcp and user_id:
try:
# 先静默检查是否有可用工具
from app.services.mcp_tool_service import mcp_tool_service
available_tools = await mcp_tool_service.get_user_enabled_tools(
user_id=user_id,
db_session=db
)
# 只有在真正有可用工具时才显示消息和调用
if available_tools:
yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集角色参考资料...", 8)
# 构建角色资料收集提示词
planning_prompt = f"""你正在为小说《{project.title}》设计角色。
【小说信息】
- 题材:{genre or project.genre}
- 主题:{theme or project.theme}
- 时代背景:{world_context.get('time_period', '未设定')}
- 地理位置:{world_context.get('location', '未设定')}
【任务】
请使用可用工具搜索相关参考资料,帮助设计更真实、更有深度的角色。
你可以查询:
1. 该时代/地域的真实历史人物特征
2. 文化背景和社会习俗
3. 职业特点和生活方式
4. 相关领域的人物原型
请查询最关键的1个问题(不要超过1个)。"""
# 调用MCP增强的AI(非流式,最多1轮工具调用,避免超时)
planning_result = await user_ai_service.generate_text_with_mcp(
prompt=planning_prompt,
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=1, # ✅ 优化: 从2轮减少到1轮
tool_choice="auto",
provider=None,
model=None
)
# 提取参考资料
if planning_result.get("tool_calls_made", 0) > 0:
yield await SSEResponse.send_progress(
f"✅ MCP工具调用成功({planning_result['tool_calls_made']}次)",
12
)
character_reference_materials = planning_result.get("content", "")
else:
# 有工具但未使用
logger.debug("MCP工具可用但AI未选择使用")
else:
# 没有可用工具,静默跳过
logger.debug(f"用户 {user_id} 未启用MCP工具,跳过MCP增强")
except Exception as e:
logger.warning(f"MCP工具调用失败(降级处理): {e}")
yield await SSEResponse.send_progress("⚠️ MCP工具暂时不可用,使用基础模式", 12)
# 优化的分批策略:每批生成3个,平衡效率和成功率
BATCH_SIZE = 3 # 每批生成3个角色
MAX_RETRIES = 3 # 每批最多重试3次
all_characters = []
total_batches = (count + BATCH_SIZE - 1) // BATCH_SIZE
for batch_idx in range(total_batches):
# 精确计算当前批次应该生成的数量
remaining = count - len(all_characters)
current_batch_size = min(BATCH_SIZE, remaining)
# 如果已经达到目标数量,直接退出
if current_batch_size <= 0:
logger.info(f"已生成{len(all_characters)}个角色,达到目标数量{count}")
break
batch_progress = 15 + (batch_idx * 60 // total_batches)
# 重试逻辑
retry_count = 0
batch_success = False
batch_error_message = ""
while retry_count < MAX_RETRIES and not batch_success:
try:
retry_suffix = f" (重试{retry_count}/{MAX_RETRIES})" if retry_count > 0 else ""
yield await SSEResponse.send_progress(
f"生成第{batch_idx+1}/{total_batches}批角色 ({current_batch_size}个){retry_suffix}...",
batch_progress
)
# 构建批次要求 - 包含已生成角色信息保持连贯
existing_chars_context = ""
if all_characters:
existing_chars_context = "\n\n【已生成的角色】:\n"
for char in all_characters:
existing_chars_context += f"- {char.get('name')}: {char.get('role_type', '未知')}, {char.get('personality', '暂无')[:50]}...\n"
existing_chars_context += "\n请确保新角色与已有角色形成合理的关系网络和互动。\n"
# 构建精确的批次要求,明确告诉AI要生成的数量
if batch_idx == 0:
if current_batch_size == 1:
batch_requirements = f"{requirements}\n请生成1个主角(protagonist)"
else:
batch_requirements = f"{requirements}\n请精确生成{current_batch_size}个角色:1个主角(protagonist)和{current_batch_size-1}个核心配角(supporting)"
else:
batch_requirements = f"{requirements}\n请精确生成{current_batch_size}个角色{existing_chars_context}"
if batch_idx == total_batches - 1:
batch_requirements += "\n可以包含组织或反派(antagonist)"
else:
batch_requirements += "\n主要是配角(supporting)和反派(antagonist)"
# 获取自定义提示词模板
template = await PromptService.get_template("CHARACTERS_BATCH", user_id, db)
# 构建基础提示词
base_prompt = PromptService.format_prompt(
template,
count=current_batch_size, # 传递精确数量
time_period=world_context.get("time_period", ""),
location=world_context.get("location", ""),
atmosphere=world_context.get("atmosphere", ""),
rules=world_context.get("rules", ""),
theme=theme or project.theme or "",
genre=genre or project.genre or "",
requirements=batch_requirements
)
# 如果有MCP参考资料,增强提示词
if character_reference_materials:
prompt = f"""{base_prompt}
【参考资料】
以下是通过MCP工具收集的真实背景资料,请参考这些信息设计更真实的角色:
{character_reference_materials}
请结合上述资料,设计符合历史/文化背景的角色。"""
else:
prompt = base_prompt
# 流式生成
accumulated_text = ""
async for chunk in user_ai_service.generate_text_stream(
prompt=prompt,
provider=provider,
model=model
):
accumulated_text += chunk
yield await SSEResponse.send_chunk(chunk)
# 解析批次结果
cleaned_text = accumulated_text.strip()
# 移除markdown代码块标记
if cleaned_text.startswith('```json'):
cleaned_text = cleaned_text[7:].lstrip('\n\r')
elif cleaned_text.startswith('```'):
cleaned_text = cleaned_text[3:].lstrip('\n\r')
if cleaned_text.endswith('```'):
cleaned_text = cleaned_text[:-3].rstrip('\n\r')
cleaned_text = cleaned_text.strip()
characters_data = json.loads(cleaned_text)
if not isinstance(characters_data, list):
characters_data = [characters_data]
# 严格验证生成数量是否精确匹配
if len(characters_data) != current_batch_size:
error_msg = f"批次{batch_idx+1}生成数量不正确: 期望{current_batch_size}个, 实际{len(characters_data)}"
logger.error(error_msg)
# 如果还有重试机会,继续重试
if retry_count < MAX_RETRIES - 1:
retry_count += 1
yield await SSEResponse.send_progress(
f"⚠️ {error_msg},准备重试...",
batch_progress,
"warning"
)
continue
else:
# 最后一次重试仍失败,直接返回错误
yield await SSEResponse.send_error(error_msg)
return
all_characters.extend(characters_data)
batch_success = True
logger.info(f"批次{batch_idx+1}成功添加{len(characters_data)}个角色,当前总数{len(all_characters)}/{count}")
except json.JSONDecodeError as e:
logger.error(f"批次{batch_idx+1}解析失败(尝试{retry_count+1}/{MAX_RETRIES}): {e}")
batch_error_message = f"JSON解析失败: {str(e)}"
retry_count += 1
if retry_count < MAX_RETRIES:
yield await SSEResponse.send_progress(
f"解析失败,准备重试...",
batch_progress,
"warning"
)
except Exception as e:
logger.error(f"批次{batch_idx+1}生成异常(尝试{retry_count+1}/{MAX_RETRIES}): {e}")
batch_error_message = f"生成异常: {str(e)}"
retry_count += 1
if retry_count < MAX_RETRIES:
yield await SSEResponse.send_progress(
f"生成异常,准备重试...",
batch_progress,
"warning"
)
# 检查批次是否成功
if not batch_success:
error_msg = f"批次{batch_idx+1}{MAX_RETRIES}次重试后仍然失败"
if batch_error_message:
error_msg += f": {batch_error_message}"
logger.error(error_msg)
yield await SSEResponse.send_error(error_msg)
return
# 保存到数据库 - 分阶段处理以保证一致性
yield await SSEResponse.send_progress("验证角色数据...", 82)
# 预处理:构建本批次所有实体的名称集合
valid_entity_names = set()
valid_organization_names = set()
for char_data in all_characters:
entity_name = char_data.get("name", "")
if entity_name:
valid_entity_names.add(entity_name)
if char_data.get("is_organization", False):
valid_organization_names.add(entity_name)
# 清理幻觉引用
cleaned_count = 0
for char_data in all_characters:
# 清理关系数组中的无效引用
if "relationships_array" in char_data and isinstance(char_data["relationships_array"], list):
original_rels = char_data["relationships_array"]
valid_rels = []
for rel in original_rels:
target_name = rel.get("target_character_name", "")
if target_name in valid_entity_names:
valid_rels.append(rel)
else:
cleaned_count += 1
logger.debug(f" 🧹 清理无效关系引用:{char_data.get('name')} -> {target_name}")
char_data["relationships_array"] = valid_rels
# 清理组织成员关系中的无效引用
if "organization_memberships" in char_data and isinstance(char_data["organization_memberships"], list):
original_orgs = char_data["organization_memberships"]
valid_orgs = []
for org_mem in original_orgs:
org_name = org_mem.get("organization_name", "")
if org_name in valid_organization_names:
valid_orgs.append(org_mem)
else:
cleaned_count += 1
logger.debug(f" 🧹 清理无效组织引用:{char_data.get('name')} -> {org_name}")
char_data["organization_memberships"] = valid_orgs
if cleaned_count > 0:
logger.info(f"✨ 清理了{cleaned_count}个AI幻觉引用")
yield await SSEResponse.send_progress(f"已清理{cleaned_count}个无效引用", 84)
yield await SSEResponse.send_progress("保存角色到数据库...", 85)
# 第一阶段:创建所有Character记录
created_characters = []
character_name_to_obj = {} # 名称到对象的映射,用于后续关系创建
for char_data in all_characters:
# 从relationships_array提取文本描述以保持向后兼容
relationships_text = ""
relationships_array = char_data.get("relationships_array", [])
if relationships_array and isinstance(relationships_array, list):
# 将关系数组转换为可读文本
rel_descriptions = []
for rel in relationships_array:
target = rel.get("target_character_name", "未知")
rel_type = rel.get("relationship_type", "关系")
desc = rel.get("description", "")
rel_descriptions.append(f"{target}({rel_type}): {desc}")
relationships_text = "; ".join(rel_descriptions)
# 兼容旧格式
elif isinstance(char_data.get("relationships"), dict):
relationships_text = json.dumps(char_data.get("relationships"), ensure_ascii=False)
elif isinstance(char_data.get("relationships"), str):
relationships_text = char_data.get("relationships")
# 判断是否为组织
is_organization = char_data.get("is_organization", False)
character = Character(
project_id=project_id,
name=char_data.get("name", "未命名角色"),
age=str(char_data.get("age", "")) if not is_organization else None,
gender=char_data.get("gender") if not is_organization else None,
is_organization=is_organization,
role_type=char_data.get("role_type", "supporting"),
personality=char_data.get("personality", ""),
background=char_data.get("background", ""),
appearance=char_data.get("appearance", ""),
relationships=relationships_text,
organization_type=char_data.get("organization_type") if is_organization else None,
organization_purpose=char_data.get("organization_purpose") if is_organization else None,
organization_members=json.dumps(char_data.get("organization_members", []), ensure_ascii=False) if is_organization else None,
traits=json.dumps(char_data.get("traits", []), ensure_ascii=False) if char_data.get("traits") else None
)
db.add(character)
created_characters.append((character, char_data))
await db.flush() # 获取所有角色的ID
# 刷新并建立名称映射
for character, _ in created_characters:
await db.refresh(character)
character_name_to_obj[character.name] = character
logger.info(f"向导创建角色:{character.name} (ID: {character.id}, 是否组织: {character.is_organization})")
# 为is_organization=True的角色创建Organization记录
yield await SSEResponse.send_progress("创建组织记录...", 87)
organization_name_to_obj = {} # 组织名称到Organization对象的映射
for character, char_data in created_characters:
if character.is_organization:
# 检查是否已存在Organization记录
org_check = await db.execute(
select(Organization).where(Organization.character_id == character.id)
)
existing_org = org_check.scalar_one_or_none()
if not existing_org:
# 创建Organization记录
org = Organization(
character_id=character.id,
project_id=project_id,
member_count=0, # 初始为0,后续添加成员时会更新
power_level=char_data.get("power_level", 50),
location=char_data.get("location"),
motto=char_data.get("motto"),
color=char_data.get("color")
)
db.add(org)
logger.info(f"向导创建组织记录:{character.name}")
else:
org = existing_org
# 建立组织名称映射(无论是新建还是已存在)
organization_name_to_obj[character.name] = org
await db.flush() # 确保Organization记录有ID
# 刷新角色以获取ID
for character, _ in created_characters:
await db.refresh(character)
# 第三阶段:创建角色间的关系
yield await SSEResponse.send_progress("创建角色关系...", 90)
relationships_created = 0
for character, char_data in created_characters:
# 跳过组织实体的角色关系处理(组织通过成员关系关联)
if character.is_organization:
continue
# 处理relationships数组
relationships_data = char_data.get("relationships_array", [])
if not relationships_data and isinstance(char_data.get("relationships"), list):
relationships_data = char_data.get("relationships")
if relationships_data and isinstance(relationships_data, list):
for rel in relationships_data:
try:
target_name = rel.get("target_character_name")
if not target_name:
logger.debug(f" ⚠️ {character.name}的关系缺少target_character_name,跳过")
continue
# 使用名称映射快速查找
target_char = character_name_to_obj.get(target_name)
if target_char:
# 避免创建重复关系
existing_rel = await db.execute(
select(CharacterRelationship).where(
CharacterRelationship.project_id == project_id,
CharacterRelationship.character_from_id == character.id,
CharacterRelationship.character_to_id == target_char.id
)
)
if existing_rel.scalar_one_or_none():
logger.debug(f" ️ 关系已存在:{character.name} -> {target_name}")
continue
relationship = CharacterRelationship(
project_id=project_id,
character_from_id=character.id,
character_to_id=target_char.id,
relationship_name=rel.get("relationship_type", "未知关系"),
intimacy_level=rel.get("intimacy_level", 50),
description=rel.get("description", ""),
started_at=rel.get("started_at"),
source="ai"
)
# 匹配预定义关系类型
rel_type_result = await db.execute(
select(RelationshipType).where(
RelationshipType.name == rel.get("relationship_type")
)
)
rel_type = rel_type_result.scalar_one_or_none()
if rel_type:
relationship.relationship_type_id = rel_type.id
db.add(relationship)
relationships_created += 1
logger.info(f" ✅ 向导创建关系:{character.name} -> {target_name} ({rel.get('relationship_type')})")
else:
logger.warning(f" ⚠️ 目标角色不存在:{character.name} -> {target_name}(可能是AI幻觉)")
except Exception as e:
logger.warning(f" ❌ 向导创建关系失败:{character.name} - {str(e)}")
continue
# 第四阶段:创建组织成员关系
yield await SSEResponse.send_progress("创建组织成员关系...", 93)
members_created = 0
for character, char_data in created_characters:
# 跳过组织实体本身
if character.is_organization:
continue
# 处理组织成员关系
org_memberships = char_data.get("organization_memberships", [])
if org_memberships and isinstance(org_memberships, list):
for membership in org_memberships:
try:
org_name = membership.get("organization_name")
if not org_name:
logger.debug(f" ⚠️ {character.name}的组织成员关系缺少organization_name,跳过")
continue
# 使用映射快速查找组织
org = organization_name_to_obj.get(org_name)
if org:
# 检查是否已存在成员关系
existing_member = await db.execute(
select(OrganizationMember).where(
OrganizationMember.organization_id == org.id,
OrganizationMember.character_id == character.id
)
)
if existing_member.scalar_one_or_none():
logger.debug(f" ️ 成员关系已存在:{character.name} -> {org_name}")
continue
# 创建成员关系
member = OrganizationMember(
organization_id=org.id,
character_id=character.id,
position=membership.get("position", "成员"),
rank=membership.get("rank", 0),
loyalty=membership.get("loyalty", 50),
joined_at=membership.get("joined_at"),
status=membership.get("status", "active"),
source="ai"
)
db.add(member)
# 更新组织成员计数
org.member_count += 1
members_created += 1
logger.info(f" ✅ 向导添加成员:{character.name} -> {org_name} ({membership.get('position')})")
else:
# 这种情况理论上已经被预处理清理了,但保留日志以防万一
logger.debug(f" ️ 组织引用已被清理:{character.name} -> {org_name}")
except Exception as e:
logger.warning(f" ❌ 向导添加组织成员失败:{character.name} - {str(e)}")
continue
logger.info(f"📊 向导数据统计:")
logger.info(f" - 创建角色/组织:{len(created_characters)}")
logger.info(f" - 创建组织详情:{len(organization_name_to_obj)}")
logger.info(f" - 创建角色关系:{relationships_created}")
logger.info(f" - 创建组织成员:{members_created}")
# 更新项目的角色数量和向导步骤状态为2(角色已完成)
project.character_count = len(created_characters)
project.wizard_step = 2
logger.info(f"✅ 更新项目角色数量: {project.character_count}")
await db.commit()
db_committed = True
# 重新提取character对象
created_characters = [char for char, _ in created_characters]
# 发送结果
yield await SSEResponse.send_result({
"message": f"成功生成{len(created_characters)}个角色/组织(分{total_batches}批完成)",
"count": len(created_characters),
"batches": total_batches,
"characters": [
{
"id": char.id,
"project_id": char.project_id,
"name": char.name,
"age": char.age,
"gender": char.gender,
"is_organization": char.is_organization,
"role_type": char.role_type,
"personality": char.personality,
"background": char.background,
"appearance": char.appearance,
"relationships": char.relationships,
"organization_type": char.organization_type,
"organization_purpose": char.organization_purpose,
"organization_members": char.organization_members,
"traits": char.traits,
"created_at": char.created_at.isoformat() if char.created_at else None,
"updated_at": char.updated_at.isoformat() if char.updated_at else None
} for char in created_characters
]
})
yield await SSEResponse.send_progress("完成!", 100, "success")
yield await SSEResponse.send_done()
except GeneratorExit:
logger.warning("角色生成器被提前关闭")
if not db_committed and db.in_transaction():
await db.rollback()
logger.info("角色生成事务已回滚(GeneratorExit")
except Exception as e:
logger.error(f"角色生成失败: {str(e)}")
if not db_committed and db.in_transaction():
await db.rollback()
logger.info("角色生成事务已回滚(异常)")
yield await SSEResponse.send_error(f"生成失败: {str(e)}")
@router.post("/characters", summary="流式批量生成角色")
async def generate_characters_stream(
request: Request,
data: Dict[str, Any],
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
"""
使用SSE流式批量生成角色,避免超时
支持MCP工具增强
"""
# 从中间件注入user_id到data中
if hasattr(request.state, 'user_id'):
data['user_id'] = request.state.user_id
return create_sse_response(characters_generator(data, db, user_ai_service))
async def outline_generator(
data: Dict[str, Any],
db: AsyncSession,
user_ai_service: AIService
) -> AsyncGenerator[str, None]:
"""大纲生成流式生成器 - 向导仅生成大纲节点,不展开章节(避免等待过久)"""
db_committed = False
try:
yield await SSEResponse.send_progress("开始生成大纲...", 5)
project_id = data.get("project_id")
# 向导固定生成3个大纲节点(不展开)
outline_count = data.get("chapter_count", 3)
narrative_perspective = data.get("narrative_perspective")
target_words = data.get("target_words", 100000)
requirements = data.get("requirements", "")
provider = data.get("provider")
model = data.get("model")
# 获取项目信息
yield await SSEResponse.send_progress("加载项目信息...", 10)
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
yield await SSEResponse.send_error("项目不存在", 404)
return
# 获取角色信息
yield await SSEResponse.send_progress("加载角色信息...", 15)
result = await db.execute(
select(Character).where(Character.project_id == project_id)
)
characters = result.scalars().all()
characters_info = "\n".join([
f"- {char.name} ({'组织' if char.is_organization else '角色'}, {char.role_type}): {char.personality[:100] if char.personality else '暂无描述'}"
for char in characters
])
# 第一阶段:生成3个粗粒度大纲节点
yield await SSEResponse.send_progress(f"生成{outline_count}个大纲节点...", 20)
outline_requirements = f"{requirements}\n\n【重要说明】这是小说的开局部分,请生成{outline_count}个大纲节点,重点关注:\n"
outline_requirements += "1. 引入主要角色和世界观设定\n"
outline_requirements += "2. 建立主线冲突和故事钩子\n"
outline_requirements += "3. 展开初期情节,为后续发展埋下伏笔\n"
outline_requirements += "4. 不要试图完结故事,这只是开始部分\n"
outline_requirements += "5. 不要在JSON字符串值中使用中文引号(""''),请使用【】或《》标记\n"
# 获取自定义提示词模板
template = await PromptService.get_template("COMPLETE_OUTLINE_GENERATION", user_id, db)
outline_prompt = PromptService.format_prompt(
template,
title=project.title,
theme=project.theme or "未设定",
genre=project.genre or "通用",
chapter_count=outline_count,
narrative_perspective=narrative_perspective,
target_words=target_words // 10, # 开局约占总字数的1/10
time_period=project.world_time_period or "未设定",
location=project.world_location or "未设定",
atmosphere=project.world_atmosphere or "未设定",
rules=project.world_rules or "未设定",
characters_info=characters_info or "暂无角色信息",
mcp_references="",
requirements=outline_requirements
)
# 流式生成大纲
accumulated_text = ""
async for chunk in user_ai_service.generate_text_stream(
prompt=outline_prompt,
provider=provider,
model=model
):
accumulated_text += chunk
yield await SSEResponse.send_chunk(chunk)
# 解析大纲结果
yield await SSEResponse.send_progress("解析大纲...", 40)
cleaned_text = accumulated_text.strip()
if cleaned_text.startswith('```json'):
cleaned_text = cleaned_text[7:].lstrip('\n\r')
elif cleaned_text.startswith('```'):
cleaned_text = cleaned_text[3:].lstrip('\n\r')
if cleaned_text.endswith('```'):
cleaned_text = cleaned_text[:-3].rstrip('\n\r')
cleaned_text = cleaned_text.strip()
try:
outline_data = json.loads(cleaned_text)
if not isinstance(outline_data, list):
outline_data = [outline_data]
except json.JSONDecodeError as e:
logger.error(f"大纲JSON解析失败: {e}")
yield await SSEResponse.send_error("大纲生成失败,请重试")
return
# 保存大纲到数据库
yield await SSEResponse.send_progress("保存大纲到数据库...", 45)
created_outlines = []
for index, outline_item in enumerate(outline_data[:outline_count], 1):
outline = Outline(
project_id=project_id,
title=outline_item.get("title", f"{index}"),
content=outline_item.get("summary", outline_item.get("content", "")),
structure=json.dumps(outline_item, ensure_ascii=False),
order_index=index
)
db.add(outline)
created_outlines.append(outline)
await db.flush() # 获取大纲ID
for outline in created_outlines:
await db.refresh(outline)
logger.info(f"✅ 成功创建{len(created_outlines)}个大纲节点")
# 根据项目的大纲模式决定是否自动创建章节
created_chapters = []
if project.outline_mode == 'one-to-one':
# 一对一模式:自动为每个大纲创建对应的章节
yield await SSEResponse.send_progress("一对一模式:自动创建章节...", 50)
for outline in created_outlines:
chapter = Chapter(
project_id=project_id,
title=outline.title,
content="", # 空内容,等待用户生成
outline_id=None, # 一对一模式下不关联outline_id
chapter_number=outline.order_index, # 使用chapter_number而不是order_index
status="pending"
)
db.add(chapter)
created_chapters.append(chapter)
await db.flush()
for chapter in created_chapters:
await db.refresh(chapter)
logger.info(f"✅ 一对一模式:自动创建了{len(created_chapters)}个章节")
yield await SSEResponse.send_progress(f"已自动创建{len(created_chapters)}个章节", 85)
else:
# 一对多模式:跳过自动创建,用户可手动展开
yield await SSEResponse.send_progress("细化模式:跳过自动创建章节", 85)
logger.info(f"📝 细化模式:跳过章节创建,用户可在大纲页面手动展开")
# 更新项目信息
project.chapter_count = len(created_chapters) # 记录实际创建的章节数
project.narrative_perspective = narrative_perspective
project.target_words = target_words
project.status = "writing"
project.wizard_status = "completed"
project.wizard_step = 3
await db.commit()
db_committed = True
logger.info(f"📊 向导大纲生成完成:")
logger.info(f" - 创建大纲节点:{len(created_outlines)}")
logger.info(f" - 创建章节:{len(created_chapters)}")
logger.info(f" - 大纲模式:{project.outline_mode}")
# 构建结果消息
if project.outline_mode == 'one-to-one':
result_message = f"成功生成{len(created_outlines)}个大纲节点并自动创建{len(created_chapters)}个章节(传统模式)"
result_note = "已自动创建章节,可直接生成内容"
else:
result_message = f"成功生成{len(created_outlines)}个大纲节点(细化模式,可在大纲页面手动展开)"
result_note = "可在大纲页面展开为多个章节"
# 发送结果
yield await SSEResponse.send_result({
"message": result_message,
"outline_count": len(created_outlines),
"chapter_count": len(created_chapters),
"outline_mode": project.outline_mode,
"outlines": [
{
"id": outline.id,
"order_index": outline.order_index,
"title": outline.title,
"content": outline.content[:100] + "..." if len(outline.content) > 100 else outline.content,
"note": result_note
} for outline in created_outlines
],
"chapters": [
{
"id": chapter.id,
"chapter_number": chapter.chapter_number,
"title": chapter.title,
"status": chapter.status
} for chapter in created_chapters
] if created_chapters else []
})
yield await SSEResponse.send_progress("完成!", 100, "success")
yield await SSEResponse.send_done()
except GeneratorExit:
logger.warning("大纲生成器被提前关闭")
if not db_committed and db.in_transaction():
await db.rollback()
logger.info("大纲生成事务已回滚(GeneratorExit")
except Exception as e:
logger.error(f"大纲生成失败: {str(e)}")
if not db_committed and db.in_transaction():
await db.rollback()
logger.info("大纲生成事务已回滚(异常)")
yield await SSEResponse.send_error(f"生成失败: {str(e)}")
@router.post("/outline", summary="流式生成完整大纲")
async def generate_outline_stream(
data: Dict[str, Any],
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
"""
使用SSE流式生成完整大纲,避免超时
"""
return create_sse_response(outline_generator(data, db, user_ai_service))
async def world_building_regenerate_generator(
project_id: str,
data: Dict[str, Any],
db: AsyncSession,
user_ai_service: AIService
) -> AsyncGenerator[str, None]:
"""世界观重新生成流式生成器"""
db_committed = False
try:
yield await SSEResponse.send_progress("开始重新生成世界观...", 10)
# 获取项目信息
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
yield await SSEResponse.send_error("项目不存在", 404)
return
# 提取参数
provider = data.get("provider")
model = data.get("model")
enable_mcp = data.get("enable_mcp", True)
user_id = data.get("user_id")
# 获取基础提示词(支持自定义)
yield await SSEResponse.send_progress("准备AI提示词...", 15)
template = await PromptService.get_template("WORLD_BUILDING", user_id, db)
base_prompt = PromptService.format_prompt(
template,
title=project.title,
theme=project.theme or "未设定",
genre=project.genre or "通用"
)
# MCP工具增强:收集参考资料
reference_materials = ""
if enable_mcp and user_id:
try:
from app.services.mcp_tool_service import mcp_tool_service
available_tools = await mcp_tool_service.get_user_enabled_tools(
user_id=user_id,
db_session=db
)
if available_tools:
yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集参考资料...", 18)
planning_prompt = f"""你正在为小说《{project.title}》重新设计世界观。
【小说信息】
- 题材:{project.genre}
- 主题:{project.theme}
- 简介:{project.description or '未设定'}
【任务】
请使用可用工具搜索相关背景资料,帮助构建更真实、更有深度的世界观设定。
你可以查询:
1. 历史背景(如果是历史题材)
2. 地理环境和文化特征
3. 相关领域的专业知识
4. 类似作品的设定参考
请查询最关键的1个问题(不要超过1个)。"""
planning_result = await user_ai_service.generate_text_with_mcp(
prompt=planning_prompt,
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=1,
tool_choice="auto",
provider=None,
model=None
)
if planning_result.get("tool_calls_made", 0) > 0:
yield await SSEResponse.send_progress(
f"✅ MCP工具调用成功({planning_result['tool_calls_made']}次)",
25
)
reference_materials = planning_result.get("content", "")
else:
logger.debug("MCP工具可用但AI未选择使用")
else:
logger.debug(f"用户 {user_id} 未启用MCP工具,跳过MCP增强")
except Exception as e:
logger.warning(f"MCP工具调用失败(降级处理): {e}")
yield await SSEResponse.send_progress("⚠️ MCP工具暂时不可用,使用基础模式", 25)
# 构建增强提示词
if reference_materials:
enhanced_prompt = f"""{base_prompt}
【参考资料】
以下是通过MCP工具收集的真实背景资料,请参考这些信息构建更真实的世界观:
{reference_materials}
请结合上述资料,生成符合历史/现实的世界观设定。"""
final_prompt = enhanced_prompt
yield await SSEResponse.send_progress("💡 已整合参考资料,开始生成世界观...", 30)
else:
final_prompt = base_prompt
yield await SSEResponse.send_progress("正在调用AI生成...", 30)
# 流式生成世界观
accumulated_text = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(
prompt=final_prompt,
provider=provider,
model=model
):
chunk_count += 1
accumulated_text += chunk
yield await SSEResponse.send_chunk(chunk)
if chunk_count % 5 == 0:
progress = min(30 + (chunk_count // 5), 70)
yield await SSEResponse.send_progress(f"生成中... ({len(accumulated_text)}字符)", progress)
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
# 解析结果
yield await SSEResponse.send_progress("解析AI返回结果...", 80)
world_data = {}
try:
cleaned_text = accumulated_text.strip()
if cleaned_text.startswith('```json'):
cleaned_text = cleaned_text[7:].lstrip('\n\r')
elif cleaned_text.startswith('```'):
cleaned_text = cleaned_text[3:].lstrip('\n\r')
if cleaned_text.endswith('```'):
cleaned_text = cleaned_text[:-3].rstrip('\n\r')
cleaned_text = cleaned_text.strip()
world_data = json.loads(cleaned_text)
except json.JSONDecodeError as e:
logger.error(f"世界构建JSON解析失败: {e}")
world_data = {
"time_period": "AI返回格式错误,请重试",
"location": "AI返回格式错误,请重试",
"atmosphere": "AI返回格式错误,请重试",
"rules": "AI返回格式错误,请重试"
}
# 不保存到数据库,仅返回生成结果供用户预览
yield await SSEResponse.send_progress("生成完成,等待用户确认...", 90)
# 发送最终结果(不包含project_id,表示未保存)
yield await SSEResponse.send_result({
"time_period": world_data.get("time_period"),
"location": world_data.get("location"),
"atmosphere": world_data.get("atmosphere"),
"rules": world_data.get("rules")
})
yield await SSEResponse.send_progress("完成!", 100, "success")
yield await SSEResponse.send_done()
except GeneratorExit:
logger.warning("世界观重新生成器被提前关闭")
if not db_committed and db.in_transaction():
await db.rollback()
logger.info("世界观重新生成事务已回滚(GeneratorExit")
except Exception as e:
logger.error(f"世界观重新生成失败: {str(e)}")
if not db_committed and db.in_transaction():
await db.rollback()
logger.info("世界观重新生成事务已回滚(异常)")
yield await SSEResponse.send_error(f"生成失败: {str(e)}")
@router.post("/world-building/{project_id}/regenerate", summary="流式重新生成世界观")
async def regenerate_world_building_stream(
project_id: str,
request: Request,
data: Dict[str, Any],
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
"""
使用SSE流式重新生成世界观,避免超时
前端使用EventSource接收实时进度和结果
"""
# 从中间件注入user_id到data中
if hasattr(request.state, 'user_id'):
data['user_id'] = request.state.user_id
return create_sse_response(world_building_regenerate_generator(project_id, data, db, user_ai_service))