Files
MuMuAINovel/backend/app/api/chapters.py
T
2025-11-05 00:11:27 +08:00

1258 lines
52 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""章节管理API"""
from fastapi import APIRouter, Depends, HTTPException, Request, Query, BackgroundTasks
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
import json
import asyncio
from typing import Optional
from datetime import datetime
from asyncio import Queue, Lock
from app.database import get_db
from app.models.chapter import Chapter
from app.models.project import Project
from app.models.outline import Outline
from app.models.character import Character
from app.models.generation_history import GenerationHistory
from app.models.writing_style import WritingStyle
from app.models.analysis_task import AnalysisTask
from app.models.memory import PlotAnalysis, StoryMemory
from app.schemas.chapter import (
ChapterCreate,
ChapterUpdate,
ChapterResponse,
ChapterListResponse,
ChapterGenerateRequest
)
from app.services.ai_service import AIService
from app.services.prompt_service import prompt_service
from app.services.plot_analyzer import PlotAnalyzer
from app.services.memory_service import memory_service
from app.logger import get_logger
from app.api.settings import get_user_ai_service
router = APIRouter(tags=["章节管理"], prefix="/chapters")
logger = get_logger(__name__)

# Global per-user database write locks: one asyncio Lock per user id, used to
# protect SQLite write operations.
db_write_locks: dict[str, Lock] = dict()
async def get_db_write_lock(user_id: str) -> Lock:
    """Return the write lock for ``user_id``, creating and caching it on first use.

    There is no ``await`` between the lookup and the insert, so within one
    event loop two callers cannot create different locks for the same user.
    """
    lock = db_write_locks.get(user_id)
    if lock is None:
        lock = Lock()
        db_write_locks[user_id] = lock
        logger.debug(f"🔒 为用户 {user_id} 创建数据库写入锁")
    return lock
@router.post("", response_model=ChapterResponse, summary="创建章节")
async def create_chapter(
    chapter: ChapterCreate,
    db: AsyncSession = Depends(get_db)
):
    """Create a new chapter and add its length to the project's word count.

    The word count is the character count of ``chapter.content``.

    Raises:
        HTTPException(404): if ``chapter.project_id`` does not reference an
            existing project.
    """
    # Make sure the owning project exists before inserting anything.
    result = await db.execute(
        select(Project).where(Project.id == chapter.project_id)
    )
    project = result.scalar_one_or_none()
    if not project:
        raise HTTPException(status_code=404, detail="项目不存在")

    # Tolerate empty/None content instead of crashing in len().
    # NOTE(review): assumes ChapterCreate.content may be optional — confirm.
    word_count = len(chapter.content or "")

    db_chapter = Chapter(
        **chapter.model_dump(),
        word_count=word_count
    )
    db.add(db_chapter)

    # Keep the project's running total in sync; a NULL total counts as 0.
    project.current_words = (project.current_words or 0) + word_count

    await db.commit()
    await db.refresh(db_chapter)
    return db_chapter
@router.get("/project/{project_id}", response_model=ChapterListResponse, summary="获取项目的所有章节")
async def get_project_chapters(
    project_id: str,
    db: AsyncSession = Depends(get_db)
):
    """List every chapter of a project, ordered by chapter number.

    ``total`` is computed from the fetched rows, so it always equals
    ``len(items)``. This also avoids a second COUNT round trip that could
    disagree with the list under concurrent writes.
    """
    result = await db.execute(
        select(Chapter)
        .where(Chapter.project_id == project_id)
        .order_by(Chapter.chapter_number)
    )
    chapters = result.scalars().all()
    return ChapterListResponse(total=len(chapters), items=chapters)
@router.get("/{chapter_id}", response_model=ChapterResponse, summary="获取章节详情")
async def get_chapter(
    chapter_id: str,
    db: AsyncSession = Depends(get_db)
):
    """Fetch a single chapter by its ID.

    Raises:
        HTTPException(404): if no chapter has this ID.
    """
    stmt = select(Chapter).where(Chapter.id == chapter_id)
    found = (await db.execute(stmt)).scalar_one_or_none()
    if not found:
        raise HTTPException(status_code=404, detail="章节不存在")
    return found
@router.get("/{chapter_id}/navigation", summary="获取章节导航信息")
async def get_chapter_navigation(
    chapter_id: str,
    db: AsyncSession = Depends(get_db)
):
    """Return navigation info (current / previous / next chapter) for a chapter.

    The chapter reader uses this for page turning. ``previous`` and ``next``
    are the nearest chapters by number within the same project, or ``None``
    when no such chapter exists.

    Raises:
        HTTPException(404): if the chapter does not exist.
    """
    def brief(ch):
        # Compact id/number/title descriptor; None when there is no chapter.
        if not ch:
            return None
        return {
            "id": ch.id,
            "chapter_number": ch.chapter_number,
            "title": ch.title
        }

    current = (await db.execute(
        select(Chapter).where(Chapter.id == chapter_id)
    )).scalar_one_or_none()
    if not current:
        raise HTTPException(status_code=404, detail="章节不存在")

    same_project = select(Chapter).where(Chapter.project_id == current.project_id)

    # Closest lower-numbered chapter.
    previous = (await db.execute(
        same_project
        .where(Chapter.chapter_number < current.chapter_number)
        .order_by(Chapter.chapter_number.desc())
        .limit(1)
    )).scalar_one_or_none()

    # Closest higher-numbered chapter.
    following = (await db.execute(
        same_project
        .where(Chapter.chapter_number > current.chapter_number)
        .order_by(Chapter.chapter_number.asc())
        .limit(1)
    )).scalar_one_or_none()

    return {
        "current": brief(current),
        "previous": brief(previous),
        "next": brief(following),
    }
@router.put("/{chapter_id}", response_model=ChapterResponse, summary="更新章节")
async def update_chapter(
    chapter_id: str,
    chapter_update: ChapterUpdate,
    db: AsyncSession = Depends(get_db)
):
    """Partially update a chapter; only explicitly set fields are applied.

    When ``content`` is part of the update, ``word_count`` is recomputed. This
    includes the case where the content is cleared. The owning project's
    ``current_words`` is adjusted by the difference and clamped at 0.

    Raises:
        HTTPException(404): if the chapter does not exist.
    """
    result = await db.execute(
        select(Chapter).where(Chapter.id == chapter_id)
    )
    chapter = result.scalar_one_or_none()
    if not chapter:
        raise HTTPException(status_code=404, detail="章节不存在")

    # Remember the old count so the project total can be adjusted by the delta.
    old_word_count = chapter.word_count or 0

    update_data = chapter_update.model_dump(exclude_unset=True)
    for field, value in update_data.items():
        setattr(chapter, field, value)

    # Recompute whenever content was sent, even if it is now empty/None.
    # Otherwise clearing a chapter would leave a stale word_count and an
    # inflated project total.
    if "content" in update_data:
        new_word_count = len(chapter.content or "")
        chapter.word_count = new_word_count

        result = await db.execute(
            select(Project).where(Project.id == chapter.project_id)
        )
        project = result.scalar_one_or_none()
        if project:
            # NULL totals count as 0; clamp so drifted totals never go negative.
            project.current_words = max(
                0, (project.current_words or 0) - old_word_count + new_word_count
            )

    await db.commit()
    await db.refresh(chapter)
    return chapter
@router.delete("/{chapter_id}", summary="删除章节")
async def delete_chapter(
    chapter_id: str,
    db: AsyncSession = Depends(get_db)
):
    """Delete a chapter and subtract its word count from the project total.

    The project total is clamped at 0. A chapter (or project) whose count is
    NULL is treated as 0.

    Raises:
        HTTPException(404): if the chapter does not exist.
    """
    result = await db.execute(
        select(Chapter).where(Chapter.id == chapter_id)
    )
    chapter = result.scalar_one_or_none()
    if not chapter:
        raise HTTPException(status_code=404, detail="章节不存在")

    result = await db.execute(
        select(Project).where(Project.id == chapter.project_id)
    )
    project = result.scalar_one_or_none()
    if project:
        # Subtracting a NULL word_count would raise TypeError; treat it as 0.
        project.current_words = max(
            0, (project.current_words or 0) - (chapter.word_count or 0)
        )

    await db.delete(chapter)
    await db.commit()
    return {"message": "章节删除成功"}
async def check_prerequisites(db: AsyncSession, chapter: Chapter) -> tuple[bool, str, list[Chapter]]:
    """Check whether every earlier chapter of the project already has content.

    Args:
        db: Database session.
        chapter: The chapter about to be generated.

    Returns:
        ``(can_generate, error_message, previous_chapters)``.
        ``previous_chapters`` holds all chapters of the same project with a
        lower number, ordered by number. Chapter 1 short-circuits to
        ``(True, "", [])``.
    """
    if chapter.chapter_number == 1:
        return True, "", []

    stmt = (
        select(Chapter)
        .where(Chapter.project_id == chapter.project_id)
        .where(Chapter.chapter_number < chapter.chapter_number)
        .order_by(Chapter.chapter_number)
    )
    earlier = (await db.execute(stmt)).scalars().all()

    # A chapter counts as incomplete when its content is missing or only whitespace.
    missing = [
        str(prev.chapter_number)
        for prev in earlier
        if not (prev.content and prev.content.strip())
    ]
    if missing:
        return False, f"需要先完成前置章节:第 {', '.join(missing)}", earlier
    return True, "", earlier
@router.get("/{chapter_id}/can-generate", summary="检查章节是否可以生成")
async def check_can_generate(
    chapter_id: str,
    db: AsyncSession = Depends(get_db)
):
    """Report whether a chapter may be generated now.

    Returns ``can_generate``, ``reason`` (empty when allowed), a summary of
    every earlier chapter, and the chapter's number.

    Raises:
        HTTPException(404): if the chapter does not exist.
    """
    lookup = await db.execute(select(Chapter).where(Chapter.id == chapter_id))
    target = lookup.scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="章节不存在")

    allowed, reason, earlier = await check_prerequisites(db, target)

    def describe(ch):
        # Per-chapter summary shown to the client.
        return {
            "id": ch.id,
            "chapter_number": ch.chapter_number,
            "title": ch.title,
            "has_content": bool(ch.content and ch.content.strip()),
            "word_count": ch.word_count or 0
        }

    return {
        "can_generate": allowed,
        "reason": "" if allowed else reason,
        "previous_chapters": [describe(ch) for ch in earlier],
        "chapter_number": target.chapter_number
    }
def _plot_analysis_fields(analysis_result: dict, analyzer: "PlotAnalyzer") -> dict:
    """Map a PlotAnalyzer result dict onto PlotAnalysis column values.

    Shared by the create and update paths so both always write the same
    columns with the same defaults for missing keys.
    """
    conflict = analysis_result.get('conflict', {})
    emotional_arc = analysis_result.get('emotional_arc', {})
    scores = analysis_result.get('scores', {})
    hooks = analysis_result.get('hooks', [])
    foreshadows = analysis_result.get('foreshadows', [])
    plot_points = analysis_result.get('plot_points', [])
    return {
        'plot_stage': analysis_result.get('plot_stage', '发展'),
        'conflict_level': conflict.get('level', 0),
        'conflict_types': conflict.get('types', []),
        'emotional_tone': emotional_arc.get('primary_emotion', ''),
        # Divided by 10; presumably maps a 0-10 intensity to 0-1 — TODO confirm.
        'emotional_intensity': emotional_arc.get('intensity', 0) / 10.0,
        'hooks': hooks,
        'hooks_count': len(hooks),
        'foreshadows': foreshadows,
        'foreshadows_planted': sum(1 for f in foreshadows if f.get('type') == 'planted'),
        'foreshadows_resolved': sum(1 for f in foreshadows if f.get('type') == 'resolved'),
        'plot_points': plot_points,
        'plot_points_count': len(plot_points),
        'character_states': analysis_result.get('character_states', []),
        'scenes': analysis_result.get('scenes', []),
        'pacing': analysis_result.get('pacing', 'moderate'),
        'overall_quality_score': scores.get('overall', 0),
        'pacing_score': scores.get('pacing', 0),
        'engagement_score': scores.get('engagement', 0),
        'coherence_score': scores.get('coherence', 0),
        'analysis_report': analyzer.generate_analysis_summary(analysis_result),
        'suggestions': analysis_result.get('suggestions', []),
        'dialogue_ratio': analysis_result.get('dialogue_ratio', 0),
        'description_ratio': analysis_result.get('description_ratio', 0),
    }


async def analyze_chapter_background(
    chapter_id: str,
    user_id: str,
    project_id: str,
    task_id: str,
    ai_service: AIService
):
    """Analyze a chapter in the background and persist the results.

    May run concurrently for several chapters. Every database write is
    guarded by the per-user lock from ``get_db_write_lock`` (the file
    describes it as protecting SQLite writes).

    Steps:
        1. Mark the AnalysisTask as running.
        2. Load the chapter.
        3. Run ``PlotAnalyzer.analyze_chapter``.
        4. Upsert the chapter's PlotAnalysis row.
        5. Replace its StoryMemory rows and push them to the vector store.
        6. Mark the task completed.

    Any exception is logged and, on a best-effort basis, recorded on the task
    as ``failed``; it is not re-raised.

    Args:
        chapter_id: ID of the chapter to analyze.
        user_id: Owner; selects the database engine and the write lock.
        project_id: Project the analysis and memories belong to.
        task_id: AnalysisTask row used to report status and progress.
        ai_service: AI service instance handed to the analyzer.
    """
    db_session = None
    write_lock = await get_db_write_lock(user_id)
    try:
        logger.info(f"🔍 开始分析章节: {chapter_id}, 任务ID: {task_id}")

        # This runs after the request finished, so it opens its own session
        # on the user's engine instead of reusing the request-scoped one.
        from app.database import get_engine
        from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession

        engine = await get_engine(user_id)
        AsyncSessionLocal = async_sessionmaker(
            engine,
            class_=AsyncSession,
            expire_on_commit=False
        )
        db_session = AsyncSessionLocal()

        # 1. Load the task (read) and mark it running (write, locked).
        task_result = await db_session.execute(
            select(AnalysisTask).where(AnalysisTask.id == task_id)
        )
        task = task_result.scalar_one_or_none()
        if not task:
            logger.error(f"❌ 任务不存在: {task_id}")
            return

        async with write_lock:
            task.status = 'running'
            task.started_at = datetime.now()
            task.progress = 10
            await db_session.commit()

        # 2. Load the chapter (read).
        chapter_result = await db_session.execute(
            select(Chapter).where(Chapter.id == chapter_id)
        )
        chapter = chapter_result.scalar_one_or_none()
        if not chapter or not chapter.content:
            async with write_lock:
                task.status = 'failed'
                task.error_message = '章节不存在或内容为空'
                task.completed_at = datetime.now()
                await db_session.commit()
            logger.error(f"❌ 章节不存在或内容为空: {chapter_id}")
            return

        async with write_lock:
            task.progress = 20
            await db_session.commit()

        # 3. AI analysis; the slow part, so no lock is held here.
        analyzer = PlotAnalyzer(ai_service)
        analysis_result = await analyzer.analyze_chapter(
            chapter_number=chapter.chapter_number,
            title=chapter.title,
            content=chapter.content,
            word_count=chapter.word_count or len(chapter.content)
        )
        if not analysis_result:
            async with write_lock:
                task.status = 'failed'
                task.error_message = 'AI分析失败,请检查日志'
                task.completed_at = datetime.now()
                await db_session.commit()
            logger.error(f"❌ AI分析失败: {chapter_id}")
            return

        async with write_lock:
            task.progress = 60
            await db_session.commit()

        # 4. Upsert this chapter's PlotAnalysis row (write, locked).
        fields = _plot_analysis_fields(analysis_result, analyzer)
        async with write_lock:
            existing_analysis_result = await db_session.execute(
                select(PlotAnalysis).where(PlotAnalysis.chapter_id == chapter_id)
            )
            existing_analysis = existing_analysis_result.scalar_one_or_none()
            if existing_analysis:
                logger.info(f" 更新现有分析记录: {existing_analysis.id}")
                for column, value in fields.items():
                    setattr(existing_analysis, column, value)
            else:
                logger.info(f" 创建新的分析记录")
                db_session.add(PlotAnalysis(
                    chapter_id=chapter_id,
                    project_id=project_id,
                    **fields
                ))
            task.progress = 80
            await db_session.commit()

        # 5. Extract memories; the content is passed so positions can be computed.
        memories = analyzer.extract_memories_from_analysis(
            analysis=analysis_result,
            chapter_id=chapter_id,
            chapter_number=chapter.chapter_number,
            chapter_content=chapter.content or ""
        )

        # Deterministic per-position ids. Records are paired with memories
        # by position (zip below) rather than via list.index(), which is
        # O(n^2) and returns the first match for equal dicts, yielding
        # duplicate primary keys.
        memory_records = [
            {
                'id': f"{chapter_id}_{mem['type']}_{index}",
                'content': mem['content'],
                'type': mem['type'],
                'metadata': mem['metadata']
            }
            for index, mem in enumerate(memories)
        ]

        # Replace the chapter's memories in ONE transaction so a failure while
        # inserting does not leave the chapter with its old memories deleted.
        async with write_lock:
            old_memories_result = await db_session.execute(
                select(StoryMemory).where(StoryMemory.chapter_id == chapter_id)
            )
            old_memories = old_memories_result.scalars().all()
            for old_mem in old_memories:
                await db_session.delete(old_mem)
            # Emit the DELETEs first: the new rows may reuse the same ids.
            await db_session.flush()
            logger.info(f" 删除旧记忆: {len(old_memories)}")

            for mem, record in zip(memories, memory_records):
                memory_id = record['id']
                metadata = mem['metadata']
                text_position = metadata.get('text_position', -1)
                text_length = metadata.get('text_length', 0)
                db_session.add(StoryMemory(
                    id=memory_id,
                    project_id=project_id,
                    chapter_id=chapter_id,
                    memory_type=mem['type'],
                    content=mem['content'],
                    title=mem['title'],
                    importance_score=metadata.get('importance_score', 0.5),
                    tags=metadata.get('tags', []),
                    is_foreshadow=metadata.get('is_foreshadow', 0),
                    story_timeline=chapter.chapter_number,
                    chapter_position=text_position,
                    text_length=text_length,
                    related_characters=metadata.get('related_characters', []),
                    related_locations=metadata.get('related_locations', [])
                ))
                if text_position >= 0:
                    logger.debug(f" 保存记忆 {memory_id}: position={text_position}, length={text_length}")
            await db_session.commit()

        # Push the same records to the vector store.
        if memory_records:
            added_count = await memory_service.batch_add_memories(
                user_id=user_id,
                project_id=project_id,
                memories=memory_records
            )
            logger.info(f"✅ 添加{added_count}条记忆到向量库")

        # 6. Mark the task completed (write, locked).
        async with write_lock:
            task.progress = 100
            task.status = 'completed'
            task.completed_at = datetime.now()
            await db_session.commit()

        logger.info(f"✅ 章节分析完成: {chapter_id}, 提取{len(memories)}条记忆")

    except Exception as e:
        logger.error(f"❌ 后台分析异常: {str(e)}", exc_info=True)
        # Best effort: record the failure on the task without re-raising.
        if db_session:
            try:
                async with write_lock:
                    # A failed flush/commit leaves the session unusable until
                    # it is rolled back; without this, the status update below
                    # would itself fail.
                    await db_session.rollback()
                    task_result = await db_session.execute(
                        select(AnalysisTask).where(AnalysisTask.id == task_id)
                    )
                    task = task_result.scalar_one_or_none()
                    if task:
                        task.status = 'failed'
                        task.error_message = str(e)[:500]
                        task.completed_at = datetime.now()
                        task.progress = 0
                        await db_session.commit()
                        logger.info(f"✅ 任务状态已更新为failed: {task_id}")
                    else:
                        logger.error(f"❌ 无法找到任务进行状态更新: {task_id}")
            except Exception as update_error:
                logger.error(f"❌ 更新任务状态失败: {str(update_error)}")
    finally:
        if db_session:
            await db_session.close()
@router.post("/{chapter_id}/generate-stream", summary="AI创作章节内容(流式)")
async def generate_chapter_content_stream(
    chapter_id: str,
    request: Request,
    background_tasks: BackgroundTasks,
    generate_request: ChapterGenerateRequest = ChapterGenerateRequest(),
    user_ai_service: AIService = Depends(get_user_ai_service)
):
    """Generate a chapter's full text with AI and stream it as Server-Sent Events.

    The prompt combines:
        - project settings and the chapter's outline;
        - all outlines and the project's characters;
        - an optional writing style;
        - earlier chapters' content and memory-enhanced context.

    Chapters must be generated in order: every earlier chapter must already
    have content.

    Request body:
        style_id: Optional writing style ID; when omitted no style is applied.
        target_word_count: Optional target length; 3000 when not provided
            (documented range 500-10000).

    SSE events (``data: <json>``):
        - ``start``, then one ``content`` event per streamed chunk;
        - ``done`` and ``analysis_started`` on success;
        - ``error`` on failure.

    On success the chapter content, word counts and a GenerationHistory row
    are committed. An AnalysisTask is then created and
    ``analyze_chapter_background`` is scheduled as a background task.

    No dependency-injected session is used. The generator opens its own
    session so no connection leaks for the lifetime of the streaming
    response.

    Raises:
        HTTPException(404): the chapter does not exist (checked before streaming).
        HTTPException(400): earlier chapters are incomplete.
    """
    style_id = generate_request.style_id
    target_word_count = generate_request.target_word_count or 3000

    # Validate up front with a short-lived session so 404/400 are returned as
    # normal HTTP errors before the streaming response starts.
    async for temp_db in get_db(request):
        try:
            result = await temp_db.execute(
                select(Chapter).where(Chapter.id == chapter_id)
            )
            chapter = result.scalar_one_or_none()
            if not chapter:
                raise HTTPException(status_code=404, detail="章节不存在")

            can_generate, error_msg, previous_chapters = await check_prerequisites(temp_db, chapter)
            if not can_generate:
                raise HTTPException(status_code=400, detail=error_msg)

            # Snapshot plain data so the generator does not depend on ORM
            # objects from this (soon closed) session.
            previous_chapters_data = [
                {
                    'id': ch.id,
                    'chapter_number': ch.chapter_number,
                    'title': ch.title,
                    'content': ch.content
                }
                for ch in previous_chapters
            ]
        finally:
            await temp_db.close()
        break

    async def event_generator():
        """Yield SSE lines for one generation run, using a session it owns."""
        db_session = None
        db_committed = False
        # Falls back to "system" when no user is attached to the request.
        current_user_id = getattr(request.state, "user_id", "system")

        try:
            async for db_session in get_db(request):
                # Re-load the chapter in this generator-owned session.
                chapter_result = await db_session.execute(
                    select(Chapter).where(Chapter.id == chapter_id)
                )
                current_chapter = chapter_result.scalar_one_or_none()
                if not current_chapter:
                    yield f"data: {json.dumps({'type': 'error', 'error': '章节不存在'}, ensure_ascii=False)}\n\n"
                    return

                project_result = await db_session.execute(
                    select(Project).where(Project.id == current_chapter.project_id)
                )
                project = project_result.scalar_one_or_none()
                if not project:
                    yield f"data: {json.dumps({'type': 'error', 'error': '项目不存在'}, ensure_ascii=False)}\n\n"
                    return

                # The outline whose order_index equals this chapter's number.
                outline_result = await db_session.execute(
                    select(Outline)
                    .where(Outline.project_id == current_chapter.project_id)
                    .where(Outline.order_index == current_chapter.chapter_number)
                    .execution_options(populate_existing=True)
                )
                outline = outline_result.scalar_one_or_none()

                # All outlines, as a compact overview for context.
                all_outlines_result = await db_session.execute(
                    select(Outline)
                    .where(Outline.project_id == current_chapter.project_id)
                    .order_by(Outline.order_index)
                    .execution_options(populate_existing=True)
                )
                all_outlines = all_outlines_result.scalars().all()
                # Outline content may be NULL; guard before slicing.
                outlines_context = "\n".join([
                    f"{o.order_index}{o.title}: {(o.content or '')[:100]}..."
                    for o in all_outlines
                ])

                # Characters and organizations of the project.
                characters_result = await db_session.execute(
                    select(Character).where(Character.project_id == current_chapter.project_id)
                )
                characters = characters_result.scalars().all()
                characters_info = "\n".join([
                    f"- {c.name}({'组织' if c.is_organization else '角色'}, {c.role_type}): {c.personality[:100] if c.personality else ''}"
                    for c in characters
                ])

                # Writing style: only global presets (project_id is NULL) or
                # this project's own custom styles may be used.
                style_content = ""
                if style_id:
                    style_result = await db_session.execute(
                        select(WritingStyle).where(WritingStyle.id == style_id)
                    )
                    style = style_result.scalar_one_or_none()
                    if style:
                        if style.project_id is None or style.project_id == current_chapter.project_id:
                            style_content = style.prompt_content or ""
                            style_type = "全局预设" if style.project_id is None else "项目自定义"
                            logger.info(f"使用指定风格: {style.name} ({style_type})")
                        else:
                            logger.warning(f"风格 {style_id} 不属于当前项目,无法使用")
                    else:
                        logger.warning(f"未找到风格 {style_id}")
                else:
                    logger.info("未指定写作风格,使用原始提示词")

                # Earlier-chapter context: the last three chapters verbatim,
                # older ones reduced to 200-character excerpts.
                previous_content = ""
                if previous_chapters_data:
                    recent_chapters = previous_chapters_data[-3:]
                    early_chapters = previous_chapters_data[:-3]

                    if early_chapters:
                        early_summary = "【前期剧情概要】\n" + "\n".join([
                            f"{ch['chapter_number']}章《{ch['title']}》:{ch['content'][:200] if ch['content'] else ''}..."
                            for ch in early_chapters
                        ])
                        previous_content += early_summary + "\n\n"

                    if recent_chapters:
                        recent_content = "【最近章节完整内容】\n" + "\n\n".join([
                            f"=== 第{ch['chapter_number']}章:{ch['title']} ===\n{ch['content']}"
                            for ch in recent_chapters
                        ])
                        previous_content += recent_content

                    logger.info(f"构建前置上下文:{len(early_chapters)}章摘要 + {len(recent_chapters)}章完整内容")

                # Memory-enhanced context from the memory service.
                logger.info(f"🧠 开始构建记忆增强上下文...")
                memory_context = await memory_service.build_context_for_generation(
                    user_id=current_user_id,
                    project_id=project.id,
                    current_chapter=current_chapter.chapter_number,
                    chapter_outline=outline.content if outline else current_chapter.summary or "",
                    character_names=[c.name for c in characters] if characters else None
                )

                # Character length of each context part, for logging only.
                context_lengths = {
                    'recent_context': len(memory_context.get('recent_context', '')),
                    'relevant_memories': len(memory_context.get('relevant_memories', '')),
                    'foreshadows': len(memory_context.get('foreshadows', '')),
                    'character_states': len(memory_context.get('character_states', '')),
                    'plot_points': len(memory_context.get('plot_points', ''))
                }
                total_memory_length = sum(context_lengths.values())

                logger.info(f"✅ 记忆上下文构建完成: {memory_context.get('stats')}")
                logger.info(f"📏 记忆上下文长度统计:")
                logger.info(f" - 最近章节记忆: {context_lengths['recent_context']} 字符")
                logger.info(f" - 语义相关记忆: {context_lengths['relevant_memories']} 字符")
                logger.info(f" - 未完结伏笔: {context_lengths['foreshadows']} 字符")
                logger.info(f" - 角色状态记忆: {context_lengths['character_states']} 字符")
                logger.info(f" - 重要情节点: {context_lengths['plot_points']} 字符")
                logger.info(f" - 记忆总长度: {total_memory_length} 字符")
                logger.info(f" - 前置章节上下文长度: {len(previous_content)} 字符")
                logger.info(f" - 总上下文长度(估算): {total_memory_length + len(previous_content) + 2000} 字符")

                yield f"data: {json.dumps({'type': 'start', 'message': '开始AI创作...'}, ensure_ascii=False)}\n\n"

                # Arguments shared by both prompt templates.
                prompt_kwargs = dict(
                    title=project.title,
                    theme=project.theme or '',
                    genre=project.genre or '',
                    narrative_perspective=project.narrative_perspective or '第三人称',
                    time_period=project.world_time_period or '未设定',
                    location=project.world_location or '未设定',
                    atmosphere=project.world_atmosphere or '未设定',
                    rules=project.world_rules or '未设定',
                    characters_info=characters_info or '暂无角色信息',
                    outlines_context=outlines_context,
                    chapter_number=current_chapter.chapter_number,
                    chapter_title=current_chapter.title,
                    chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲',
                    style_content=style_content,
                    target_word_count=target_word_count,
                    memory_context=memory_context
                )
                # The with-context template is used only when earlier content exists.
                if previous_content:
                    prompt = prompt_service.get_chapter_generation_with_context_prompt(
                        previous_content=previous_content,
                        **prompt_kwargs
                    )
                else:
                    prompt = prompt_service.get_chapter_generation_prompt(**prompt_kwargs)

                logger.info(f"开始AI流式创作章节 {chapter_id}")

                # Forward each chunk to the client as it arrives.
                full_content = ""
                async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
                    full_content += chunk
                    yield f"data: {json.dumps({'type': 'content', 'content': chunk}, ensure_ascii=False)}\n\n"
                    await asyncio.sleep(0)  # yield control to the event loop

                # Persist the generated content and keep the project total in sync.
                old_word_count = current_chapter.word_count or 0
                current_chapter.content = full_content
                new_word_count = len(full_content)
                current_chapter.word_count = new_word_count
                current_chapter.status = "completed"
                # A NULL project total counts as 0.
                project.current_words = (project.current_words or 0) - old_word_count + new_word_count

                # History keeps at most the first 500 characters.
                history = GenerationHistory(
                    project_id=current_chapter.project_id,
                    chapter_id=current_chapter.id,
                    prompt=f"创作章节: 第{current_chapter.chapter_number}{current_chapter.title}",
                    generated_content=full_content[:500],
                    model="default"
                )
                db_session.add(history)

                await db_session.commit()
                db_committed = True
                await db_session.refresh(current_chapter)

                logger.info(f"成功创作章节 {chapter_id},共 {new_word_count}")

                analysis_task = AnalysisTask(
                    chapter_id=chapter_id,
                    user_id=current_user_id,
                    project_id=project.id,
                    status='pending',
                    progress=0
                )
                db_session.add(analysis_task)
                await db_session.commit()
                await db_session.refresh(analysis_task)

                task_id = analysis_task.id
                logger.info(f"📋 已创建分析任务: {task_id}")

                # No artificial delay is needed before scheduling: the task
                # row is committed, and background tasks only run after the
                # streaming response has finished.
                background_tasks.add_task(
                    analyze_chapter_background,
                    chapter_id=chapter_id,
                    user_id=current_user_id,
                    project_id=project.id,
                    task_id=task_id,
                    ai_service=user_ai_service
                )

                completion_data = {
                    'type': 'done',
                    'message': '创作完成',
                    'word_count': new_word_count,
                    'analysis_task_id': task_id
                }
                yield f"data: {json.dumps(completion_data, ensure_ascii=False)}\n\n"

                analysis_started_data = {
                    'type': 'analysis_started',
                    'task_id': task_id,
                    'message': '章节分析已开始'
                }
                yield f"data: {json.dumps(analysis_started_data, ensure_ascii=False)}\n\n"

                break  # only one session is needed

        except GeneratorExit:
            # The client disconnected mid-stream; discard uncommitted work.
            logger.warning("章节生成器被提前关闭(SSE断开)")
            if db_session and not db_committed:
                try:
                    if db_session.in_transaction():
                        await db_session.rollback()
                        logger.info("章节生成事务已回滚(GeneratorExit")
                except Exception as e:
                    logger.error(f"GeneratorExit回滚失败: {str(e)}")
        except Exception as e:
            logger.error(f"流式创作章节失败: {str(e)}")
            if db_session and not db_committed:
                try:
                    if db_session.in_transaction():
                        await db_session.rollback()
                        logger.info("章节生成事务已回滚(异常)")
                except Exception as rollback_error:
                    logger.error(f"回滚失败: {str(rollback_error)}")
            yield f"data: {json.dumps({'type': 'error', 'error': str(e)}, ensure_ascii=False)}\n\n"
        finally:
            # Always release the session, rolling back anything left uncommitted.
            if db_session:
                try:
                    if not db_committed and db_session.in_transaction():
                        await db_session.rollback()
                        logger.warning("在finally中发现未提交事务,已回滚")
                    await db_session.close()
                    logger.info("数据库会话已关闭")
                except Exception as close_error:
                    logger.error(f"关闭数据库会话失败: {str(close_error)}")
                    # Last-resort close; failures here are deliberately ignored.
                    try:
                        await db_session.close()
                    except Exception:
                        pass

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no"
        }
    )
@router.get("/{chapter_id}/analysis/status", summary="查询章节分析任务状态")
async def get_analysis_task_status(
    chapter_id: str,
    db: AsyncSession = Depends(get_db)
):
    """Return the status of a chapter's most recently created analysis task.

    Response fields:
        - task_id and chapter_id;
        - status: pending / running / completed / failed;
        - progress: 0-100;
        - error_message;
        - created_at, started_at and completed_at as ISO-8601 strings
          (None when unset).

    Raises:
        HTTPException(404): if the chapter has no analysis task.
    """
    latest_stmt = (
        select(AnalysisTask)
        .where(AnalysisTask.chapter_id == chapter_id)
        .order_by(AnalysisTask.created_at.desc())
        .limit(1)
    )
    latest = (await db.execute(latest_stmt)).scalar_one_or_none()
    if not latest:
        raise HTTPException(status_code=404, detail="未找到分析任务")

    def iso(moment):
        # ISO-8601 string, or None for an unset timestamp.
        return moment.isoformat() if moment else None

    return {
        "task_id": latest.id,
        "chapter_id": latest.chapter_id,
        "status": latest.status,
        "progress": latest.progress,
        "error_message": latest.error_message,
        "created_at": iso(latest.created_at),
        "started_at": iso(latest.started_at),
        "completed_at": iso(latest.completed_at)
    }
@router.get("/{chapter_id}/analysis", summary="获取章节分析结果")
async def get_chapter_analysis(
    chapter_id: str,
    db: AsyncSession = Depends(get_db)
):
    """Return a chapter's latest analysis together with its extracted memories.

    Response fields:
        - chapter_id;
        - analysis: ``PlotAnalysis.to_dict()``;
        - memories: ordered by importance, highest first;
        - created_at: ISO-8601 time of the analysis, or None.

    Raises:
        HTTPException(404): if the chapter has no analysis yet.
    """
    analysis = (await db.execute(
        select(PlotAnalysis)
        .where(PlotAnalysis.chapter_id == chapter_id)
        .order_by(PlotAnalysis.created_at.desc())
        .limit(1)
    )).scalar_one_or_none()
    if not analysis:
        raise HTTPException(status_code=404, detail="该章节暂无分析结果")

    memory_rows = (await db.execute(
        select(StoryMemory)
        .where(StoryMemory.chapter_id == chapter_id)
        .order_by(StoryMemory.importance_score.desc())
    )).scalars().all()

    serialized_memories = []
    for row in memory_rows:
        serialized_memories.append({
            "id": row.id,
            "type": row.memory_type,
            "title": row.title,
            "content": row.content,
            "importance": row.importance_score,
            "tags": row.tags,
            "is_foreshadow": row.is_foreshadow,
            "position": row.chapter_position,
            "related_characters": row.related_characters
        })

    return {
        "chapter_id": chapter_id,
        "analysis": analysis.to_dict(),
        "memories": serialized_memories,
        "created_at": analysis.created_at.isoformat() if analysis.created_at else None
    }
# Where to find analysis items for each memory type:
# memory_type -> (PlotAnalysis attribute, item key to match, memory attribute it must occur in)
_ANNOTATION_SOURCES = {
    'hook': ('hooks', 'type', 'title'),
    'foreshadow': ('foreshadows', 'content', 'content'),
    'plot_point': ('plot_points', 'content', 'content'),
}


def _matching_items(items, mem, item_key, mem_attr):
    """Yield analysis items whose ``item_key`` value occurs in ``getattr(mem, mem_attr)``.

    Empty or missing values on either side are skipped. ``None in text``
    raises TypeError, and ``'' in text`` is always true, so an empty value
    would match the wrong item.
    """
    haystack = getattr(mem, mem_attr, None)
    if not haystack:
        return
    for item in items:
        needle = item.get(item_key)
        if needle and needle in haystack:
            yield item


def _annotation_extra(memory_type, item):
    """Return the extra display metadata taken from a matched analysis item."""
    if memory_type == 'hook':
        return {"strength": item.get('strength', 5), "position_desc": item.get('position', '')}
    if memory_type == 'foreshadow':
        return {"foreshadow_type": item.get('type', 'planted'), "strength": item.get('strength', 5)}
    return {}


@router.get("/{chapter_id}/annotations", summary="获取章节标注数据")
async def get_chapter_annotations(
    chapter_id: str,
    db: AsyncSession = Depends(get_db)
):
    """Return annotation data used by the frontend to mark up a chapter's text.

    Each annotation is built from a StoryMemory row and carries a character
    ``position`` and ``length`` within the chapter (-1 / 0 when unknown).

    - When a memory has no stored position, it is recovered by locating the
      matching analysis item's ``keyword`` in the chapter text.
    - When the position is stored, only extra metadata is taken from the
      first matching analysis item.

    Raises:
        HTTPException(404): if the chapter does not exist.
    """
    from collections import Counter

    chapter_result = await db.execute(
        select(Chapter).where(Chapter.id == chapter_id)
    )
    chapter = chapter_result.scalar_one_or_none()
    if not chapter:
        raise HTTPException(status_code=404, detail="章节不存在")

    # The latest analysis, which may be absent.
    analysis_result = await db.execute(
        select(PlotAnalysis)
        .where(PlotAnalysis.chapter_id == chapter_id)
        .order_by(PlotAnalysis.created_at.desc())
        .limit(1)
    )
    analysis = analysis_result.scalar_one_or_none()

    memories_result = await db.execute(
        select(StoryMemory)
        .where(StoryMemory.chapter_id == chapter_id)
        .order_by(StoryMemory.importance_score.desc())
    )
    memories = memories_result.scalars().all()

    annotations = []
    for mem in memories:
        # Prefer the position and length stored in the database.
        position = mem.chapter_position if mem.chapter_position is not None else -1
        length = mem.text_length if getattr(mem, 'text_length', None) is not None else 0
        metadata_extra = {}

        source = _ANNOTATION_SOURCES.get(mem.memory_type)
        items = getattr(analysis, source[0], None) if analysis and source else None
        if items:
            matches = _matching_items(items, mem, source[1], source[2])
            if position == -1 and chapter.content:
                # No stored position: use the first matching item whose
                # keyword actually occurs in the chapter text.
                for item in matches:
                    keyword = item.get('keyword', '')
                    pos = chapter.content.find(keyword) if keyword else -1
                    if pos != -1:
                        position = pos
                        length = len(keyword)
                        metadata_extra = _annotation_extra(mem.memory_type, item)
                        break
            else:
                # Position known, or no text to search: only borrow metadata
                # from the first matching item.
                first_match = next(matches, None)
                if first_match is not None:
                    metadata_extra = _annotation_extra(mem.memory_type, first_match)

        annotations.append({
            "id": mem.id,
            "type": mem.memory_type,
            "title": mem.title,
            "content": mem.content,
            "importance": mem.importance_score or 0.5,
            "position": position,
            "length": length,
            "tags": mem.tags or [],
            "metadata": {
                "is_foreshadow": mem.is_foreshadow,
                "related_characters": mem.related_characters or [],
                "related_locations": mem.related_locations or [],
                **metadata_extra
            }
        })

    type_counts = Counter(a["type"] for a in annotations)
    return {
        "chapter_id": chapter_id,
        "chapter_number": chapter.chapter_number,
        "title": chapter.title,
        "word_count": chapter.word_count or 0,
        "annotations": annotations,
        "has_analysis": analysis is not None,
        "summary": {
            "total_annotations": len(annotations),
            "hooks": type_counts["hook"],
            "foreshadows": type_counts["foreshadow"],
            "plot_points": type_counts["plot_point"],
            "character_events": type_counts["character_event"]
        }
    }
@router.post("/{chapter_id}/analyze", summary="手动触发章节分析")
async def trigger_chapter_analysis(
    chapter_id: str,
    request: Request,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db),
    user_ai_service: AIService = Depends(get_user_ai_service)
):
    """Manually (re-)run analysis for a chapter, e.g. to re-analyze older chapters.

    Creates a pending AnalysisTask and schedules ``analyze_chapter_background``
    as a background task, then returns immediately with the task ID.

    Raises:
        HTTPException(401): no user ID on ``request.state``.
        HTTPException(404): the chapter or its project does not exist.
        HTTPException(400): the chapter content is empty.
    """
    user_id = getattr(request.state, "user_id", None)
    if not user_id:
        raise HTTPException(status_code=401, detail="未登录")

    chapter_result = await db.execute(
        select(Chapter).where(Chapter.id == chapter_id)
    )
    chapter = chapter_result.scalar_one_or_none()
    if not chapter:
        raise HTTPException(status_code=404, detail="章节不存在")
    if not chapter.content or chapter.content.strip() == "":
        raise HTTPException(status_code=400, detail="章节内容为空,无法分析")

    project_result = await db.execute(
        select(Project).where(Project.id == chapter.project_id)
    )
    project = project_result.scalar_one_or_none()
    if not project:
        raise HTTPException(status_code=404, detail="项目不存在")

    analysis_task = AnalysisTask(
        chapter_id=chapter_id,
        user_id=user_id,
        project_id=project.id,
        status='pending',
        progress=0
    )
    db.add(analysis_task)
    await db.commit()
    # Refresh BEFORE reading attributes. If the session expires objects on
    # commit, touching analysis_task.id would trigger implicit lazy IO, which
    # AsyncSession does not allow.
    await db.refresh(analysis_task)
    task_id = analysis_task.id
    logger.info(f"📋 创建分析任务: {task_id}, 章节: {chapter_id}")

    # No artificial delay: the task row is already committed (visible to new
    # sessions), and background tasks only run after the response is sent.
    background_tasks.add_task(
        analyze_chapter_background,
        chapter_id=chapter_id,
        user_id=user_id,
        project_id=project.id,
        task_id=task_id,
        ai_service=user_ai_service
    )

    return {
        "task_id": task_id,
        "chapter_id": chapter_id,
        "status": "pending",
        "message": "分析任务已创建并开始执行"
    }