update:1.更新根据分析建议重新生成章节内容
This commit is contained in:
@@ -1135,386 +1135,3 @@ async def generate_outline_stream(
|
||||
"""
|
||||
return create_sse_response(outline_generator(data, db, user_ai_service))
|
||||
|
||||
|
||||
async def update_world_building_generator(
|
||||
project_id: str,
|
||||
data: Dict[str, Any],
|
||||
db: AsyncSession
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""更新世界观流式生成器"""
|
||||
db_committed = False
|
||||
try:
|
||||
yield await SSEResponse.send_progress("开始更新世界观...", 10)
|
||||
|
||||
# 获取项目
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
yield await SSEResponse.send_error("项目不存在", 404)
|
||||
return
|
||||
|
||||
yield await SSEResponse.send_progress("验证数据...", 30)
|
||||
|
||||
# 更新世界观字段
|
||||
if "time_period" in data:
|
||||
project.world_time_period = data["time_period"]
|
||||
if "location" in data:
|
||||
project.world_location = data["location"]
|
||||
if "atmosphere" in data:
|
||||
project.world_atmosphere = data["atmosphere"]
|
||||
if "rules" in data:
|
||||
project.world_rules = data["rules"]
|
||||
|
||||
yield await SSEResponse.send_progress("保存到数据库...", 70)
|
||||
|
||||
await db.commit()
|
||||
db_committed = True
|
||||
await db.refresh(project)
|
||||
|
||||
# 发送结果
|
||||
yield await SSEResponse.send_result({
|
||||
"project_id": project.id,
|
||||
"time_period": project.world_time_period,
|
||||
"location": project.world_location,
|
||||
"atmosphere": project.world_atmosphere,
|
||||
"rules": project.world_rules
|
||||
})
|
||||
|
||||
yield await SSEResponse.send_progress("完成!", 100, "success")
|
||||
yield await SSEResponse.send_done()
|
||||
|
||||
except GeneratorExit:
|
||||
logger.warning("更新世界观生成器被提前关闭")
|
||||
if not db_committed and db.in_transaction():
|
||||
await db.rollback()
|
||||
logger.info("更新世界观事务已回滚(GeneratorExit)")
|
||||
except Exception as e:
|
||||
logger.error(f"更新世界观失败: {str(e)}")
|
||||
if not db_committed and db.in_transaction():
|
||||
await db.rollback()
|
||||
logger.info("更新世界观事务已回滚(异常)")
|
||||
yield await SSEResponse.send_error(f"更新失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/world-building/{project_id}", summary="流式更新世界观")
|
||||
async def update_world_building_stream(
|
||||
project_id: str,
|
||||
data: Dict[str, Any],
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
使用SSE流式更新项目的世界观信息
|
||||
请求体格式:
|
||||
{
|
||||
"time_period": "时间背景",
|
||||
"location": "地理位置",
|
||||
"atmosphere": "氛围基调",
|
||||
"rules": "世界规则"
|
||||
}
|
||||
"""
|
||||
return create_sse_response(update_world_building_generator(project_id, data, db))
|
||||
|
||||
|
||||
async def regenerate_world_building_generator(
|
||||
project_id: str,
|
||||
data: Dict[str, Any],
|
||||
db: AsyncSession,
|
||||
user_ai_service: AIService
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""重新生成世界观流式生成器 - 支持MCP工具增强"""
|
||||
db_committed = False
|
||||
try:
|
||||
yield await SSEResponse.send_progress("开始重新生成世界观...", 10)
|
||||
|
||||
# 获取项目
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
yield await SSEResponse.send_error("项目不存在", 404)
|
||||
return
|
||||
|
||||
provider = data.get("provider")
|
||||
model = data.get("model")
|
||||
enable_mcp = data.get("enable_mcp", True) # 默认启用MCP
|
||||
user_id = data.get("user_id") # 从中间件注入
|
||||
|
||||
# 获取基础提示词
|
||||
yield await SSEResponse.send_progress("准备AI提示词...", 15)
|
||||
base_prompt = prompt_service.get_world_building_prompt(
|
||||
title=project.title,
|
||||
theme=project.theme or "",
|
||||
genre=project.genre or ""
|
||||
)
|
||||
|
||||
# MCP工具增强:收集参考资料
|
||||
reference_materials = ""
|
||||
if enable_mcp and user_id:
|
||||
try:
|
||||
yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集参考资料...", 18)
|
||||
|
||||
# 直接调用MCP增强的AI,内部会自动检查和加载工具
|
||||
# 构建资料收集提示词
|
||||
planning_prompt = f"""你正在为小说《{project.title}》重新设计世界观。
|
||||
|
||||
【小说信息】
|
||||
- 题材:{project.genre or '未设定'}
|
||||
- 主题:{project.theme or '未设定'}
|
||||
|
||||
【任务】
|
||||
请使用可用工具搜索相关背景资料,帮助构建更真实、更有深度的世界观设定。
|
||||
你可以查询:
|
||||
1. 历史背景(如果是历史题材)
|
||||
2. 地理环境和文化特征
|
||||
3. 相关领域的专业知识
|
||||
4. 类似作品的设定参考
|
||||
|
||||
请根据题材特点,有针对性地查询2-3个关键问题。"""
|
||||
|
||||
# 调用MCP增强的AI(非流式,最多2轮工具调用)
|
||||
planning_result = await user_ai_service.generate_text_with_mcp(
|
||||
prompt=planning_prompt,
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=2,
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
)
|
||||
|
||||
# 提取参考资料
|
||||
if planning_result.get("tool_calls_made", 0) > 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"✅ MCP工具调用成功({planning_result['tool_calls_made']}次)",
|
||||
25
|
||||
)
|
||||
reference_materials = planning_result.get("content", "")
|
||||
else:
|
||||
yield await SSEResponse.send_progress("ℹ️ 未使用MCP工具(无可用工具或不需要)", 25)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"MCP工具调用失败(降级处理): {e}")
|
||||
yield await SSEResponse.send_progress("⚠️ MCP工具暂时不可用,使用基础模式", 25)
|
||||
|
||||
# 构建增强提示词
|
||||
if reference_materials:
|
||||
enhanced_prompt = f"""{base_prompt}
|
||||
|
||||
【参考资料】
|
||||
以下是通过MCP工具收集的真实背景资料,请参考这些信息构建更真实的世界观:
|
||||
|
||||
{reference_materials}
|
||||
|
||||
请结合上述资料,生成符合历史/现实的世界观设定。"""
|
||||
final_prompt = enhanced_prompt
|
||||
yield await SSEResponse.send_progress("💡 已整合参考资料,开始重新生成世界观...", 30)
|
||||
else:
|
||||
final_prompt = base_prompt
|
||||
yield await SSEResponse.send_progress("正在调用AI生成...", 30)
|
||||
|
||||
# 流式生成世界观
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=final_prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
):
|
||||
chunk_count += 1
|
||||
accumulated_text += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度
|
||||
if chunk_count % 5 == 0:
|
||||
progress = min(30 + (chunk_count // 5), 70)
|
||||
yield await SSEResponse.send_progress(f"生成中... ({len(accumulated_text)}字符)", progress)
|
||||
|
||||
# 每20个块发送心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
# 解析结果
|
||||
yield await SSEResponse.send_progress("解析AI返回结果...", 80)
|
||||
|
||||
world_data = {}
|
||||
try:
|
||||
cleaned_text = accumulated_text.strip()
|
||||
# 移除markdown代码块标记
|
||||
if cleaned_text.startswith('```json'):
|
||||
cleaned_text = cleaned_text[7:].lstrip('\n\r')
|
||||
elif cleaned_text.startswith('```'):
|
||||
cleaned_text = cleaned_text[3:].lstrip('\n\r')
|
||||
if cleaned_text.endswith('```'):
|
||||
cleaned_text = cleaned_text[:-3].rstrip('\n\r')
|
||||
cleaned_text = cleaned_text.strip()
|
||||
|
||||
world_data = json.loads(cleaned_text)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"AI返回非JSON格式: {e}")
|
||||
logger.info(world_data)
|
||||
world_data = {
|
||||
"time_period": "AI返回格式错误,请重试",
|
||||
"location": "AI返回格式错误,请重试",
|
||||
"atmosphere": "AI返回格式错误,请重试",
|
||||
"rules": "AI返回格式错误,请重试"
|
||||
}
|
||||
|
||||
# 更新项目世界观
|
||||
yield await SSEResponse.send_progress("保存到数据库...", 90)
|
||||
|
||||
project.world_time_period = world_data.get("time_period")
|
||||
project.world_location = world_data.get("location")
|
||||
project.world_atmosphere = world_data.get("atmosphere")
|
||||
project.world_rules = world_data.get("rules")
|
||||
|
||||
await db.commit()
|
||||
db_committed = True
|
||||
await db.refresh(project)
|
||||
|
||||
# 发送结果
|
||||
yield await SSEResponse.send_result({
|
||||
"project_id": project.id,
|
||||
"time_period": project.world_time_period,
|
||||
"location": project.world_location,
|
||||
"atmosphere": project.world_atmosphere,
|
||||
"rules": project.world_rules
|
||||
})
|
||||
|
||||
yield await SSEResponse.send_progress("完成!", 100, "success")
|
||||
yield await SSEResponse.send_done()
|
||||
|
||||
except GeneratorExit:
|
||||
logger.warning("重新生成世界观生成器被提前关闭")
|
||||
if not db_committed and db.in_transaction():
|
||||
await db.rollback()
|
||||
logger.info("重新生成世界观事务已回滚(GeneratorExit)")
|
||||
except Exception as e:
|
||||
logger.error(f"重新生成世界观失败: {str(e)}")
|
||||
if not db_committed and db.in_transaction():
|
||||
await db.rollback()
|
||||
logger.info("重新生成世界观事务已回滚(异常)")
|
||||
yield await SSEResponse.send_error(f"重新生成失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/world-building/{project_id}/regenerate", summary="流式重新生成世界观")
|
||||
async def regenerate_world_building_stream(
|
||||
request: Request,
|
||||
project_id: str,
|
||||
data: Dict[str, Any],
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
"""
|
||||
使用SSE流式重新生成项目的世界观
|
||||
请求体格式:
|
||||
{
|
||||
"provider": "AI提供商(可选)",
|
||||
"model": "模型名称(可选)"
|
||||
}
|
||||
"""
|
||||
# 从中间件注入user_id到data中
|
||||
if hasattr(request.state, 'user_id'):
|
||||
data['user_id'] = request.state.user_id
|
||||
|
||||
return create_sse_response(regenerate_world_building_generator(project_id, data, db, user_ai_service))
|
||||
|
||||
|
||||
async def cleanup_wizard_data_generator(
|
||||
project_id: str,
|
||||
db: AsyncSession
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""清理向导数据流式生成器"""
|
||||
db_committed = False
|
||||
try:
|
||||
yield await SSEResponse.send_progress("开始清理向导数据...", 10)
|
||||
|
||||
# 获取项目
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
yield await SSEResponse.send_error("项目不存在", 404)
|
||||
return
|
||||
|
||||
# 删除相关的角色
|
||||
yield await SSEResponse.send_progress("删除角色数据...", 30)
|
||||
characters = await db.execute(
|
||||
select(Character).where(Character.project_id == project_id)
|
||||
)
|
||||
char_count = 0
|
||||
for character in characters.scalars():
|
||||
await db.delete(character)
|
||||
char_count += 1
|
||||
|
||||
# 删除相关的大纲
|
||||
yield await SSEResponse.send_progress("删除大纲数据...", 50)
|
||||
outlines = await db.execute(
|
||||
select(Outline).where(Outline.project_id == project_id)
|
||||
)
|
||||
outline_count = 0
|
||||
for outline in outlines.scalars():
|
||||
await db.delete(outline)
|
||||
outline_count += 1
|
||||
|
||||
# 删除相关的章节
|
||||
yield await SSEResponse.send_progress("删除章节数据...", 70)
|
||||
chapters = await db.execute(
|
||||
select(Chapter).where(Chapter.project_id == project_id)
|
||||
)
|
||||
chapter_count = 0
|
||||
for chapter in chapters.scalars():
|
||||
await db.delete(chapter)
|
||||
chapter_count += 1
|
||||
|
||||
# 删除项目
|
||||
yield await SSEResponse.send_progress("删除项目...", 85)
|
||||
await db.delete(project)
|
||||
|
||||
yield await SSEResponse.send_progress("提交数据库更改...", 95)
|
||||
await db.commit()
|
||||
db_committed = True
|
||||
|
||||
# 发送结果
|
||||
yield await SSEResponse.send_result({
|
||||
"message": "项目及相关数据已清理",
|
||||
"deleted": {
|
||||
"characters": char_count,
|
||||
"outlines": outline_count,
|
||||
"chapters": chapter_count
|
||||
}
|
||||
})
|
||||
|
||||
yield await SSEResponse.send_progress("完成!", 100, "success")
|
||||
yield await SSEResponse.send_done()
|
||||
|
||||
except GeneratorExit:
|
||||
logger.warning("清理向导数据生成器被提前关闭")
|
||||
if not db_committed and db.in_transaction():
|
||||
await db.rollback()
|
||||
logger.info("清理向导数据事务已回滚(GeneratorExit)")
|
||||
except Exception as e:
|
||||
logger.error(f"清理数据失败: {str(e)}")
|
||||
if not db_committed and db.in_transaction():
|
||||
await db.rollback()
|
||||
logger.info("清理向导数据事务已回滚(异常)")
|
||||
yield await SSEResponse.send_error(f"清理失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/cleanup/{project_id}", summary="流式清理向导数据")
|
||||
async def cleanup_wizard_data_stream(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
使用SSE流式清理向导过程中创建的项目及相关数据
|
||||
用于返回上一步时清理已生成的内容
|
||||
"""
|
||||
return create_sse_response(cleanup_wizard_data_generator(project_id, db))
|
||||
Reference in New Issue
Block a user