update:新增用户API配置提示 优化大纲全新/续写的分批生成

This commit is contained in:
xiamuceer
2025-10-30 22:01:10 +08:00
parent c4f8cd78f0
commit 12ded06b36
7 changed files with 830 additions and 91 deletions
+534 -67
View File
@@ -2,7 +2,7 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, delete
from typing import List
from typing import List, AsyncGenerator, Dict, Any
import json
from app.database import get_db
@@ -23,6 +23,7 @@ from app.services.ai_service import AIService
from app.services.prompt_service import prompt_service
from app.logger import get_logger
from app.api.settings import get_user_ai_service
from app.utils.sse_response import SSEResponse, create_sse_response
router = APIRouter(prefix="/outlines", tags=["大纲管理"])
logger = get_logger(__name__)
@@ -479,27 +480,21 @@ async def _continue_outline(
db: AsyncSession,
user_ai_service: AIService
) -> OutlineListResponse:
"""续写大纲"""
"""续写大纲 - 分批生成,每批5章"""
logger.info(f"续写大纲 - 项目: {project.id}, 已有: {len(existing_outlines)}")
# 分析已有大纲
current_chapter_count = len(existing_outlines)
last_chapter_number = existing_outlines[-1].order_index
# 获取最近2章的剧情
recent_outlines = existing_outlines[-2:] if len(existing_outlines) >= 2 else existing_outlines
recent_plot = "\n".join([
f"{o.order_index}章《{o.title}》: {o.content}"
for o in recent_outlines
])
# logger.debug(f"最近三章内容:{recent_plot}")
# 全部章节概览
all_chapters_brief = "\n".join([
f"{o.order_index}章: {o.title}"
for o in existing_outlines
])
# 计算需要生成的总章数和批次
total_chapters_to_generate = request.chapter_count
batch_size = 5 # 每批生成5章
total_batches = (total_chapters_to_generate + batch_size - 1) // batch_size
# 获取角色信息
logger.info(f"分批生成计划: 总共{total_chapters_to_generate}章,分{total_batches}批,每批{batch_size}")
# 获取角色信息(所有批次共用)
characters_result = await db.execute(
select(Character).where(Character.project_id == project.id)
)
@@ -518,65 +513,104 @@ async def _continue_outline(
}
stage_instruction = stage_instructions.get(request.plot_stage, "")
# 使用标准续写提示词模板
prompt = prompt_service.get_outline_continue_prompt(
title=project.title,
theme=request.theme or project.theme or "未设定",
genre=request.genre or project.genre or "通用",
narrative_perspective=request.narrative_perspective,
chapter_count=request.chapter_count,
time_period=project.world_time_period or "未设定",
location=project.world_location or "未设定",
atmosphere=project.world_atmosphere or "未设定",
rules=project.world_rules or "未设定",
characters_info=characters_info or "暂无角色信息",
current_chapter_count=current_chapter_count,
all_chapters_brief=all_chapters_brief,
recent_plot=recent_plot,
plot_stage_instruction=stage_instruction,
start_chapter=last_chapter_number + 1,
story_direction=request.story_direction or "自然延续",
requirements=request.requirements or ""
)
# 批量生成
all_new_outlines = []
current_start_chapter = last_chapter_number + 1
# 调用AI
ai_response = await user_ai_service.generate_text(
prompt=prompt,
provider=request.provider,
model=request.model
)
# 解析响应
outline_data = _parse_ai_response(ai_response)
# 保存续写的大纲
new_outlines = await _save_outlines(
project.id, outline_data, db, start_index=last_chapter_number + 1
)
# 记录历史
history = GenerationHistory(
project_id=project.id,
prompt=prompt,
generated_content=ai_response,
model=request.model or "default"
)
db.add(history)
await db.commit()
for outline in new_outlines:
await db.refresh(outline)
for batch_num in range(total_batches):
# 计算当前批次的章节数
remaining_chapters = total_chapters_to_generate - len(all_new_outlines)
current_batch_size = min(batch_size, remaining_chapters)
logger.info(f"开始生成第{batch_num + 1}/{total_batches}批,章节范围: {current_start_chapter}-{current_start_chapter + current_batch_size - 1}")
# 获取最新的大纲列表(包括之前批次生成的)
latest_result = await db.execute(
select(Outline)
.where(Outline.project_id == project.id)
.order_by(Outline.order_index)
)
latest_outlines = latest_result.scalars().all()
# 获取最近2章的剧情
recent_outlines = latest_outlines[-2:] if len(latest_outlines) >= 2 else latest_outlines
recent_plot = "\n".join([
f"{o.order_index}章《{o.title}》: {o.content}"
for o in recent_outlines
])
# 全部章节概览
all_chapters_brief = "\n".join([
f"{o.order_index}章: {o.title}"
for o in latest_outlines
])
# 使用标准续写提示词模板
prompt = prompt_service.get_outline_continue_prompt(
title=project.title,
theme=request.theme or project.theme or "未设定",
genre=request.genre or project.genre or "通用",
narrative_perspective=request.narrative_perspective,
chapter_count=current_batch_size, # 当前批次的章节数
time_period=project.world_time_period or "未设定",
location=project.world_location or "未设定",
atmosphere=project.world_atmosphere or "未设定",
rules=project.world_rules or "未设定",
characters_info=characters_info or "暂无角色信息",
current_chapter_count=len(latest_outlines),
all_chapters_brief=all_chapters_brief,
recent_plot=recent_plot,
plot_stage_instruction=stage_instruction,
start_chapter=current_start_chapter,
story_direction=request.story_direction or "自然延续",
requirements=request.requirements or ""
)
# 调用AI生成当前批次
logger.info(f"正在调用AI生成第{batch_num + 1}批...")
ai_response = await user_ai_service.generate_text(
prompt=prompt,
provider=request.provider,
model=request.model
)
# 解析响应
outline_data = _parse_ai_response(ai_response)
# 保存当前批次的大纲
batch_outlines = await _save_outlines(
project.id, outline_data, db, start_index=current_start_chapter
)
# 记录历史
history = GenerationHistory(
project_id=project.id,
prompt=f"[批次{batch_num + 1}/{total_batches}] {str(prompt)[:500]}",
generated_content=ai_response,
model=request.model or "default"
)
db.add(history)
# 提交当前批次
await db.commit()
for outline in batch_outlines:
await db.refresh(outline)
all_new_outlines.extend(batch_outlines)
current_start_chapter += current_batch_size
logger.info(f"{batch_num + 1}批生成完成,本批生成{len(batch_outlines)}")
# 返回所有大纲(包括旧的和新的)
all_result = await db.execute(
final_result = await db.execute(
select(Outline)
.where(Outline.project_id == project.id)
.order_by(Outline.order_index)
)
all_outlines = all_result.scalars().all()
all_outlines = final_result.scalars().all()
logger.info(f"续写完成 - 新增 {len(new_outlines)} 章,总计 {len(all_outlines)}")
logger.info(f"续写完成 - {total_batches}批,新增 {len(all_new_outlines)} 章,总计 {len(all_outlines)}")
return OutlineListResponse(total=len(all_outlines), items=all_outlines)
@@ -658,4 +692,437 @@ async def _save_outlines(
)
db.add(chapter)
return outlines
return outlines
async def new_outline_generator(
data: Dict[str, Any],
db: AsyncSession,
user_ai_service: AIService
) -> AsyncGenerator[str, None]:
"""全新生成大纲SSE生成器"""
db_committed = False
try:
yield await SSEResponse.send_progress("开始生成大纲...", 5)
project_id = data.get("project_id")
# 确保chapter_count是整数(前端可能传字符串)
chapter_count = int(data.get("chapter_count", 10))
# 验证项目
yield await SSEResponse.send_progress("加载项目信息...", 10)
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
yield await SSEResponse.send_error("项目不存在", 404)
return
yield await SSEResponse.send_progress(f"准备生成{chapter_count}章大纲...", 15)
# 获取角色信息
characters_result = await db.execute(
select(Character).where(Character.project_id == project_id)
)
characters = characters_result.scalars().all()
characters_info = "\n".join([
f"- {char.name} ({'组织' if char.is_organization else '角色'}, {char.role_type}): "
f"{char.personality[:100] if char.personality else '暂无描述'}"
for char in characters
])
# 使用完整提示词
yield await SSEResponse.send_progress("准备AI提示词...", 20)
prompt = prompt_service.get_complete_outline_prompt(
title=project.title,
theme=data.get("theme") or project.theme or "未设定",
genre=data.get("genre") or project.genre or "通用",
chapter_count=chapter_count,
narrative_perspective=data.get("narrative_perspective") or "第三人称",
target_words=data.get("target_words") or project.target_words or 100000,
time_period=project.world_time_period or "未设定",
location=project.world_location or "未设定",
atmosphere=project.world_atmosphere or "未设定",
rules=project.world_rules or "未设定",
characters_info=characters_info or "暂无角色信息",
requirements=data.get("requirements") or ""
)
# 调用AI
yield await SSEResponse.send_progress("🤖 正在调用AI生成...", 30)
ai_response = await user_ai_service.generate_text(
prompt=prompt,
provider=data.get("provider"),
model=data.get("model")
)
yield await SSEResponse.send_progress("✅ AI生成完成,正在解析...", 70)
# 解析响应
outline_data = _parse_ai_response(ai_response)
# 删除旧大纲和章节
yield await SSEResponse.send_progress("清理旧数据...", 75)
logger.info(f"删除项目 {project_id} 的旧大纲和章节")
await db.execute(
delete(Outline).where(Outline.project_id == project_id)
)
await db.execute(
delete(Chapter).where(Chapter.project_id == project_id)
)
# 保存新大纲
yield await SSEResponse.send_progress("💾 保存大纲到数据库...", 80)
outlines = await _save_outlines(
project_id, outline_data, db, start_index=1
)
# 记录历史
history = GenerationHistory(
project_id=project_id,
prompt=prompt,
generated_content=ai_response,
model=data.get("model") or "default"
)
db.add(history)
await db.commit()
db_committed = True
for outline in outlines:
await db.refresh(outline)
yield await SSEResponse.send_progress("整理结果数据...", 95)
logger.info(f"全新生成完成 - {len(outlines)}")
# 发送最终结果
yield await SSEResponse.send_result({
"message": f"成功生成{len(outlines)}章大纲",
"total_chapters": len(outlines),
"outlines": [
{
"id": outline.id,
"project_id": outline.project_id,
"title": outline.title,
"content": outline.content,
"order_index": outline.order_index,
"structure": outline.structure,
"created_at": outline.created_at.isoformat() if outline.created_at else None,
"updated_at": outline.updated_at.isoformat() if outline.updated_at else None
} for outline in outlines
]
})
yield await SSEResponse.send_progress("🎉 生成完成!", 100, "success")
yield await SSEResponse.send_done()
except GeneratorExit:
logger.warning("大纲生成器被提前关闭")
if not db_committed and db.in_transaction():
await db.rollback()
logger.info("大纲生成事务已回滚(GeneratorExit")
except Exception as e:
logger.error(f"大纲生成失败: {str(e)}")
if not db_committed and db.in_transaction():
await db.rollback()
logger.info("大纲生成事务已回滚(异常)")
yield await SSEResponse.send_error(f"生成失败: {str(e)}")
async def continue_outline_generator(
data: Dict[str, Any],
db: AsyncSession,
user_ai_service: AIService
) -> AsyncGenerator[str, None]:
"""大纲续写SSE生成器 - 分批生成,推送进度"""
db_committed = False
try:
yield await SSEResponse.send_progress("开始续写大纲...", 5)
project_id = data.get("project_id")
# 确保chapter_count是整数(前端可能传字符串)
total_chapters_to_generate = int(data.get("chapter_count", 5))
# 验证项目
yield await SSEResponse.send_progress("加载项目信息...", 10)
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
yield await SSEResponse.send_error("项目不存在", 404)
return
# 获取现有大纲
yield await SSEResponse.send_progress("分析已有大纲...", 15)
existing_result = await db.execute(
select(Outline)
.where(Outline.project_id == project_id)
.order_by(Outline.order_index)
)
existing_outlines = existing_result.scalars().all()
if not existing_outlines:
yield await SSEResponse.send_error("续写模式需要已有大纲,当前项目没有大纲", 400)
return
current_chapter_count = len(existing_outlines)
last_chapter_number = existing_outlines[-1].order_index
yield await SSEResponse.send_progress(
f"当前已有{str(current_chapter_count)}章,将续写{str(total_chapters_to_generate)}",
20
)
# 获取角色信息
characters_result = await db.execute(
select(Character).where(Character.project_id == project_id)
)
characters = characters_result.scalars().all()
characters_info = "\n".join([
f"- {char.name} ({'组织' if char.is_organization else '角色'}, {char.role_type}): "
f"{char.personality[:100] if char.personality else '暂无描述'}"
for char in characters
])
# 分批配置
batch_size = 5
total_batches = (total_chapters_to_generate + batch_size - 1) // batch_size
yield await SSEResponse.send_progress(
f"分批生成计划: 总共{str(total_chapters_to_generate)}章,分{str(total_batches)}批,每批{str(batch_size)}",
25
)
# 情节阶段指导
stage_instructions = {
"development": "继续展开情节,深化角色关系,推进主线冲突",
"climax": "进入故事高潮,矛盾激化,关键冲突爆发",
"ending": "解决主要冲突,收束伏笔,给出结局"
}
stage_instruction = stage_instructions.get(data.get("plot_stage", "development"), "")
# 批量生成
all_new_outlines = []
current_start_chapter = last_chapter_number + 1
for batch_num in range(total_batches):
# 计算当前批次的章节数
remaining_chapters = int(total_chapters_to_generate) - len(all_new_outlines)
current_batch_size = min(batch_size, remaining_chapters)
batch_progress = 25 + (batch_num * 60 // total_batches)
yield await SSEResponse.send_progress(
f"📝 第{str(batch_num + 1)}/{str(total_batches)}批: 生成第{str(current_start_chapter)}-{str(current_start_chapter + current_batch_size - 1)}",
batch_progress
)
# 获取最新的大纲列表(包括之前批次生成的)
latest_result = await db.execute(
select(Outline)
.where(Outline.project_id == project_id)
.order_by(Outline.order_index)
)
latest_outlines = latest_result.scalars().all()
# 获取最近2章的剧情
recent_outlines = latest_outlines[-2:] if len(latest_outlines) >= 2 else latest_outlines
recent_plot = "\n".join([
f"{o.order_index}章《{o.title}》: {o.content}"
for o in recent_outlines
])
# 全部章节概览
all_chapters_brief = "\n".join([
f"{o.order_index}章: {o.title}"
for o in latest_outlines
])
yield await SSEResponse.send_progress(
f"🤖 调用AI生成第{str(batch_num + 1)}批...",
batch_progress + 5
)
# 使用标准续写提示词模板
prompt = prompt_service.get_outline_continue_prompt(
title=project.title,
theme=data.get("theme") or project.theme or "未设定",
genre=data.get("genre") or project.genre or "通用",
narrative_perspective=data.get("narrative_perspective") or project.narrative_perspective or "第三人称",
chapter_count=current_batch_size,
time_period=project.world_time_period or "未设定",
location=project.world_location or "未设定",
atmosphere=project.world_atmosphere or "未设定",
rules=project.world_rules or "未设定",
characters_info=characters_info or "暂无角色信息",
current_chapter_count=len(latest_outlines),
all_chapters_brief=all_chapters_brief,
recent_plot=recent_plot,
plot_stage_instruction=stage_instruction,
start_chapter=current_start_chapter,
story_direction=data.get("story_direction", "自然延续"),
requirements=data.get("requirements", "")
)
# 调用AI生成当前批次
ai_response = await user_ai_service.generate_text(
prompt=prompt,
provider=data.get("provider"),
model=data.get("model")
)
yield await SSEResponse.send_progress(
f"✅ 第{str(batch_num + 1)}批AI生成完成,正在解析...",
batch_progress + 10
)
# 解析响应
outline_data = _parse_ai_response(ai_response)
# 保存当前批次的大纲
batch_outlines = await _save_outlines(
project_id, outline_data, db, start_index=current_start_chapter
)
# 记录历史
history = GenerationHistory(
project_id=project_id,
prompt=f"[续写批次{batch_num + 1}/{total_batches}] {str(prompt)[:500]}",
generated_content=ai_response,
model=data.get("model") or "default"
)
db.add(history)
# 提交当前批次
await db.commit()
for outline in batch_outlines:
await db.refresh(outline)
all_new_outlines.extend(batch_outlines)
current_start_chapter += current_batch_size
yield await SSEResponse.send_progress(
f"💾 第{str(batch_num + 1)}批保存成功!本批生成{str(len(batch_outlines))}章,累计新增{str(len(all_new_outlines))}",
batch_progress + 15
)
logger.info(f"{str(batch_num + 1)}批生成完成,本批生成{str(len(batch_outlines))}")
db_committed = True
# 返回所有大纲(包括旧的和新的)
final_result = await db.execute(
select(Outline)
.where(Outline.project_id == project_id)
.order_by(Outline.order_index)
)
all_outlines = final_result.scalars().all()
yield await SSEResponse.send_progress("整理结果数据...", 95)
# 发送最终结果
yield await SSEResponse.send_result({
"message": f"续写完成!共{str(total_batches)}批,新增{str(len(all_new_outlines))}章,总计{str(len(all_outlines))}",
"total_batches": total_batches,
"new_chapters": len(all_new_outlines),
"total_chapters": len(all_outlines),
"outlines": [
{
"id": outline.id,
"project_id": outline.project_id,
"title": outline.title,
"content": outline.content,
"order_index": outline.order_index,
"structure": outline.structure,
"created_at": outline.created_at.isoformat() if outline.created_at else None,
"updated_at": outline.updated_at.isoformat() if outline.updated_at else None
} for outline in all_outlines
]
})
yield await SSEResponse.send_progress("🎉 续写完成!", 100, "success")
yield await SSEResponse.send_done()
except GeneratorExit:
logger.warning("大纲续写生成器被提前关闭")
if not db_committed and db.in_transaction():
await db.rollback()
logger.info("大纲续写事务已回滚(GeneratorExit")
except Exception as e:
logger.error(f"大纲续写失败: {str(e)}")
if not db_committed and db.in_transaction():
await db.rollback()
logger.info("大纲续写事务已回滚(异常)")
yield await SSEResponse.send_error(f"续写失败: {str(e)}")
@router.post("/generate-stream", summary="AI生成/续写大纲(SSE流式)")
async def generate_outline_stream(
data: Dict[str, Any],
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
"""
使用SSE流式生成或续写小说大纲,实时推送批次进度
支持模式:
- auto: 自动判断(无大纲→新建,有大纲→续写)
- new: 全新生成
- continue: 续写模式
请求体示例:
{
"project_id": "项目ID",
"chapter_count": 5, // 章节数
"mode": "auto", // auto/new/continue
"theme": "故事主题", // new模式必需
"story_direction": "故事发展方向", // continue模式可选
"plot_stage": "development", // continue模式:development/climax/ending
"narrative_perspective": "第三人称",
"requirements": "其他要求",
"provider": "openai", // 可选
"model": "gpt-4" // 可选
}
"""
# 验证项目是否存在
result = await db.execute(
select(Project).where(Project.id == data.get("project_id"))
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 判断模式
mode = data.get("mode", "auto")
# 获取现有大纲
existing_result = await db.execute(
select(Outline)
.where(Outline.project_id == data.get("project_id"))
.order_by(Outline.order_index)
)
existing_outlines = existing_result.scalars().all()
# 自动判断模式
if mode == "auto":
mode = "continue" if existing_outlines else "new"
logger.info(f"自动判断模式:{'续写' if existing_outlines else '新建'}")
# 根据模式选择生成器
if mode == "new":
return create_sse_response(new_outline_generator(data, db, user_ai_service))
elif mode == "continue":
if not existing_outlines:
raise HTTPException(
status_code=400,
detail="续写模式需要已有大纲,当前项目没有大纲"
)
return create_sse_response(continue_outline_generator(data, db, user_ai_service))
else:
raise HTTPException(
status_code=400,
detail=f"不支持的模式: {mode}"
)
+1 -1
View File
@@ -148,7 +148,7 @@ async def get_db(request: Request):
logger.debug(f"📊 会话关闭 [User:{user_id}][ID:{session_id}] - 活跃:{_session_stats['active']}, 总创建:{_session_stats['created']}, 总关闭:{_session_stats['closed']}, 错误:{_session_stats['errors']}")
if _session_stats["active"] > 10:
if _session_stats["active"] > 100:
logger.warning(f"🚨 活跃会话数过多: {_session_stats['active']},可能存在连接泄漏!")
elif _session_stats["active"] < 0:
logger.error(f"🚨 活跃会话数异常: {_session_stats['active']},统计可能不准确!")
+2 -2
View File
@@ -150,7 +150,7 @@ class PromptService:
3. 不要引用任何本批次中不存在的角色或组织名称
4. 文本描述中不要使用中文引号(""),改用【】或《》"""
# 完整大纲生成提示词
# 向导大纲生成提示词
COMPLETE_OUTLINE_GENERATION = """你是一位经验丰富的小说作家和编剧。请根据以下信息生成完整的{chapter_count}章小说大纲:
基本信息:
@@ -639,7 +639,7 @@ class PromptService:
target_words: int, time_period: str, location: str,
atmosphere: str, rules: str, characters_info: str,
requirements: str = "") -> str:
"""获取完整大纲生成提示词"""
"""获取向导大纲生成提示词"""
return cls.format_prompt(
cls.COMPLETE_OUTLINE_GENERATION,
title=title,