update:1.更新导入导出功能 2.实现RAG记忆功能,引入剧情分析功能

This commit is contained in:
xiamuceer
2025-11-04 14:38:59 +08:00
parent 1cde345ed9
commit e4f90d5da0
26 changed files with 6722 additions and 84 deletions
+712 -6
View File
@@ -1,11 +1,12 @@
"""章节管理API"""
from fastapi import APIRouter, Depends, HTTPException, Request, Query
from fastapi import APIRouter, Depends, HTTPException, Request, Query, BackgroundTasks
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
import json
import asyncio
from typing import Optional
from datetime import datetime
from app.database import get_db
from app.models.chapter import Chapter
@@ -14,6 +15,8 @@ from app.models.outline import Outline
from app.models.character import Character
from app.models.generation_history import GenerationHistory
from app.models.writing_style import WritingStyle
from app.models.analysis_task import AnalysisTask
from app.models.memory import PlotAnalysis, StoryMemory
from app.schemas.chapter import (
ChapterCreate,
ChapterUpdate,
@@ -23,6 +26,8 @@ from app.schemas.chapter import (
)
from app.services.ai_service import AIService
from app.services.prompt_service import prompt_service
from app.services.plot_analyzer import PlotAnalyzer
from app.services.memory_service import memory_service
from app.logger import get_logger
from app.api.settings import get_user_ai_service
@@ -101,6 +106,63 @@ async def get_chapter(
return chapter
@router.get("/{chapter_id}/navigation", summary="获取章节导航信息")
async def get_chapter_navigation(
chapter_id: str,
db: AsyncSession = Depends(get_db)
):
"""
获取章节的导航信息(上一章/下一章)
用于章节阅读器的翻页功能
"""
# 获取当前章节
result = await db.execute(
select(Chapter).where(Chapter.id == chapter_id)
)
current_chapter = result.scalar_one_or_none()
if not current_chapter:
raise HTTPException(status_code=404, detail="章节不存在")
# 获取上一章
prev_result = await db.execute(
select(Chapter)
.where(Chapter.project_id == current_chapter.project_id)
.where(Chapter.chapter_number < current_chapter.chapter_number)
.order_by(Chapter.chapter_number.desc())
.limit(1)
)
prev_chapter = prev_result.scalar_one_or_none()
# 获取下一章
next_result = await db.execute(
select(Chapter)
.where(Chapter.project_id == current_chapter.project_id)
.where(Chapter.chapter_number > current_chapter.chapter_number)
.order_by(Chapter.chapter_number.asc())
.limit(1)
)
next_chapter = next_result.scalar_one_or_none()
return {
"current": {
"id": current_chapter.id,
"chapter_number": current_chapter.chapter_number,
"title": current_chapter.title
},
"previous": {
"id": prev_chapter.id,
"chapter_number": prev_chapter.chapter_number,
"title": prev_chapter.title
} if prev_chapter else None,
"next": {
"id": next_chapter.id,
"chapter_number": next_chapter.chapter_number,
"title": next_chapter.title
} if next_chapter else None
}
@router.put("/{chapter_id}", response_model=ChapterResponse, summary="更新章节")
async def update_chapter(
chapter_id: str,
@@ -248,10 +310,273 @@ async def check_can_generate(
}
async def analyze_chapter_background(
chapter_id: str,
user_id: str,
project_id: str,
task_id: str,
ai_service: AIService
):
"""
后台异步分析章节
Args:
chapter_id: 章节ID
user_id: 用户ID
project_id: 项目ID
task_id: 任务ID
ai_service: AI服务实例
"""
db_session = None
try:
logger.info(f"🔍 开始后台分析章节: {chapter_id}")
# 等待一小段时间,确保主会话的commit已经持久化到磁盘
await asyncio.sleep(0.1)
# 创建独立数据库会话
from app.database import get_engine
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession
engine = await get_engine(user_id)
AsyncSessionLocal = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False
)
db_session = AsyncSessionLocal()
# 1. 获取任务(添加重试逻辑)
task = None
for retry in range(3):
task_result = await db_session.execute(
select(AnalysisTask).where(AnalysisTask.id == task_id)
)
task = task_result.scalar_one_or_none()
if task:
break
if retry < 2:
logger.warning(f"⚠️ 第{retry+1}次未找到任务 {task_id},等待后重试...")
await asyncio.sleep(0.2)
if not task:
logger.error(f"❌ 任务不存在: {task_id}")
return
task.status = 'running'
task.started_at = datetime.now()
task.progress = 10
await db_session.commit()
# 2. 获取章节信息
chapter_result = await db_session.execute(
select(Chapter).where(Chapter.id == chapter_id)
)
chapter = chapter_result.scalar_one_or_none()
if not chapter or not chapter.content:
task.status = 'failed'
task.error_message = '章节不存在或内容为空'
task.completed_at = datetime.now()
await db_session.commit()
logger.error(f"❌ 章节不存在或内容为空: {chapter_id}")
return
task.progress = 20
await db_session.commit()
# 3. 使用PlotAnalyzer分析章节
analyzer = PlotAnalyzer(ai_service)
analysis_result = await analyzer.analyze_chapter(
chapter_number=chapter.chapter_number,
title=chapter.title,
content=chapter.content,
word_count=chapter.word_count or len(chapter.content)
)
if not analysis_result:
task.status = 'failed'
task.error_message = 'AI分析失败,请检查日志'
task.completed_at = datetime.now()
await db_session.commit()
logger.error(f"❌ AI分析失败: {chapter_id}")
return
task.progress = 60
await db_session.commit()
# 4. 保存分析结果到数据库(先检查是否已存在)
existing_analysis_result = await db_session.execute(
select(PlotAnalysis).where(PlotAnalysis.chapter_id == chapter_id)
)
existing_analysis = existing_analysis_result.scalar_one_or_none()
if existing_analysis:
# 更新现有记录
logger.info(f" 更新现有分析记录: {existing_analysis.id}")
existing_analysis.plot_stage = analysis_result.get('plot_stage', '发展')
existing_analysis.conflict_level = analysis_result.get('conflict', {}).get('level', 0)
existing_analysis.conflict_types = analysis_result.get('conflict', {}).get('types', [])
existing_analysis.emotional_tone = analysis_result.get('emotional_arc', {}).get('primary_emotion', '')
existing_analysis.emotional_intensity = analysis_result.get('emotional_arc', {}).get('intensity', 0) / 10.0
existing_analysis.hooks = analysis_result.get('hooks', [])
existing_analysis.hooks_count = len(analysis_result.get('hooks', []))
existing_analysis.foreshadows = analysis_result.get('foreshadows', [])
existing_analysis.foreshadows_planted = sum(1 for f in analysis_result.get('foreshadows', []) if f.get('type') == 'planted')
existing_analysis.foreshadows_resolved = sum(1 for f in analysis_result.get('foreshadows', []) if f.get('type') == 'resolved')
existing_analysis.plot_points = analysis_result.get('plot_points', [])
existing_analysis.plot_points_count = len(analysis_result.get('plot_points', []))
existing_analysis.character_states = analysis_result.get('character_states', [])
existing_analysis.scenes = analysis_result.get('scenes', [])
existing_analysis.pacing = analysis_result.get('pacing', 'moderate')
existing_analysis.overall_quality_score = analysis_result.get('scores', {}).get('overall', 0)
existing_analysis.pacing_score = analysis_result.get('scores', {}).get('pacing', 0)
existing_analysis.engagement_score = analysis_result.get('scores', {}).get('engagement', 0)
existing_analysis.coherence_score = analysis_result.get('scores', {}).get('coherence', 0)
existing_analysis.analysis_report = analyzer.generate_analysis_summary(analysis_result)
existing_analysis.suggestions = analysis_result.get('suggestions', [])
existing_analysis.dialogue_ratio = analysis_result.get('dialogue_ratio', 0)
existing_analysis.description_ratio = analysis_result.get('description_ratio', 0)
else:
# 创建新记录
logger.info(f" 创建新的分析记录")
plot_analysis = PlotAnalysis(
chapter_id=chapter_id,
project_id=project_id,
plot_stage=analysis_result.get('plot_stage', '发展'),
conflict_level=analysis_result.get('conflict', {}).get('level', 0),
conflict_types=analysis_result.get('conflict', {}).get('types', []),
emotional_tone=analysis_result.get('emotional_arc', {}).get('primary_emotion', ''),
emotional_intensity=analysis_result.get('emotional_arc', {}).get('intensity', 0) / 10.0,
hooks=analysis_result.get('hooks', []),
hooks_count=len(analysis_result.get('hooks', [])),
foreshadows=analysis_result.get('foreshadows', []),
foreshadows_planted=sum(1 for f in analysis_result.get('foreshadows', []) if f.get('type') == 'planted'),
foreshadows_resolved=sum(1 for f in analysis_result.get('foreshadows', []) if f.get('type') == 'resolved'),
plot_points=analysis_result.get('plot_points', []),
plot_points_count=len(analysis_result.get('plot_points', [])),
character_states=analysis_result.get('character_states', []),
scenes=analysis_result.get('scenes', []),
pacing=analysis_result.get('pacing', 'moderate'),
overall_quality_score=analysis_result.get('scores', {}).get('overall', 0),
pacing_score=analysis_result.get('scores', {}).get('pacing', 0),
engagement_score=analysis_result.get('scores', {}).get('engagement', 0),
coherence_score=analysis_result.get('scores', {}).get('coherence', 0),
analysis_report=analyzer.generate_analysis_summary(analysis_result),
suggestions=analysis_result.get('suggestions', []),
dialogue_ratio=analysis_result.get('dialogue_ratio', 0),
description_ratio=analysis_result.get('description_ratio', 0)
)
db_session.add(plot_analysis)
await db_session.commit()
task.progress = 80
await db_session.commit()
# 5. 提取记忆并保存到向量数据库(传入章节内容用于计算位置)
memories = analyzer.extract_memories_from_analysis(
analysis=analysis_result,
chapter_id=chapter_id,
chapter_number=chapter.chapter_number,
chapter_content=chapter.content or ""
)
# 先删除该章节的旧记忆(支持重新分析)
old_memories_result = await db_session.execute(
select(StoryMemory).where(StoryMemory.chapter_id == chapter_id)
)
old_memories = old_memories_result.scalars().all()
for old_mem in old_memories:
await db_session.delete(old_mem)
logger.info(f" 删除旧记忆: {len(old_memories)}")
# 准备批量添加的记忆数据
memory_records = []
for mem in memories:
memory_id = f"{chapter_id}_{mem['type']}_{len(memory_records)}"
memory_records.append({
'id': memory_id,
'content': mem['content'],
'type': mem['type'],
'metadata': mem['metadata']
})
# 从metadata中提取位置信息
text_position = mem['metadata'].get('text_position', -1)
text_length = mem['metadata'].get('text_length', 0)
# 同时保存到关系数据库
story_memory = StoryMemory(
id=memory_id,
project_id=project_id,
chapter_id=chapter_id,
memory_type=mem['type'],
content=mem['content'],
title=mem['title'],
importance_score=mem['metadata'].get('importance_score', 0.5),
tags=mem['metadata'].get('tags', []),
is_foreshadow=mem['metadata'].get('is_foreshadow', 0),
story_timeline=chapter.chapter_number, # 使用章节序号作为时间线
chapter_position=text_position, # 保存文本位置
text_length=text_length, # 保存文本长度
related_characters=mem['metadata'].get('related_characters', []),
related_locations=mem['metadata'].get('related_locations', [])
)
db_session.add(story_memory)
# 记录日志便于调试
if text_position >= 0:
logger.debug(f" 保存记忆 {memory_id}: position={text_position}, length={text_length}")
await db_session.commit()
# 批量添加到向量数据库
if memory_records:
added_count = await memory_service.batch_add_memories(
user_id=user_id,
project_id=project_id,
memories=memory_records
)
logger.info(f"✅ 添加{added_count}条记忆到向量库")
task.progress = 100
task.status = 'completed'
task.completed_at = datetime.now()
await db_session.commit()
logger.info(f"✅ 章节分析完成: {chapter_id}, 提取{len(memories)}条记忆")
except Exception as e:
logger.error(f"❌ 后台分析异常: {str(e)}", exc_info=True)
# 确保任务状态被更新为failed,避免前端一直轮询
if db_session:
try:
# 重新获取任务以确保有最新状态
task_result = await db_session.execute(
select(AnalysisTask).where(AnalysisTask.id == task_id)
)
task = task_result.scalar_one_or_none()
if task:
task.status = 'failed'
task.error_message = str(e)[:500]
task.completed_at = datetime.now()
task.progress = 0 # 重置进度
await db_session.commit()
logger.info(f"✅ 任务状态已更新为failed: {task_id}")
else:
logger.error(f"❌ 无法找到任务进行状态更新: {task_id}")
except Exception as update_error:
logger.error(f"❌ 更新任务状态失败: {str(update_error)}")
finally:
if db_session:
await db_session.close()
@router.post("/{chapter_id}/generate-stream", summary="AI创作章节内容(流式)")
async def generate_chapter_content_stream(
chapter_id: str,
request: Request,
background_tasks: BackgroundTasks,
generate_request: ChapterGenerateRequest = ChapterGenerateRequest(),
user_ai_service: AIService = Depends(get_user_ai_service)
):
@@ -301,6 +626,9 @@ async def generate_chapter_content_stream(
# 在生成器内部创建独立的数据库会话
db_session = None
db_committed = False
# 获取当前用户ID(在生成器外部就需要)
current_user_id = getattr(request.state, "user_id", "system")
try:
# 创建新的数据库会话
async for db_session in get_db(request):
@@ -396,11 +724,42 @@ async def generate_chapter_content_stream(
previous_content += recent_content
logger.info(f"构建前置上下文:{len(early_chapters)}章摘要 + {len(recent_chapters)}章完整内容")
# 🧠 构建记忆增强上下文
logger.info(f"🧠 开始构建记忆增强上下文...")
memory_context = await memory_service.build_context_for_generation(
user_id=current_user_id,
project_id=project.id,
current_chapter=current_chapter.chapter_number,
chapter_outline=outline.content if outline else current_chapter.summary or "",
character_names=[c.name for c in characters] if characters else None
)
# 计算各部分的字符长度
context_lengths = {
'recent_context': len(memory_context.get('recent_context', '')),
'relevant_memories': len(memory_context.get('relevant_memories', '')),
'foreshadows': len(memory_context.get('foreshadows', '')),
'character_states': len(memory_context.get('character_states', '')),
'plot_points': len(memory_context.get('plot_points', ''))
}
total_memory_length = sum(context_lengths.values())
logger.info(f"✅ 记忆上下文构建完成: {memory_context['stats']}")
logger.info(f"📏 记忆上下文长度统计:")
logger.info(f" - 最近章节记忆: {context_lengths['recent_context']} 字符")
logger.info(f" - 语义相关记忆: {context_lengths['relevant_memories']} 字符")
logger.info(f" - 未完结伏笔: {context_lengths['foreshadows']} 字符")
logger.info(f" - 角色状态记忆: {context_lengths['character_states']} 字符")
logger.info(f" - 重要情节点: {context_lengths['plot_points']} 字符")
logger.info(f" - 记忆总长度: {total_memory_length} 字符")
logger.info(f" - 前置章节上下文长度: {len(previous_content)} 字符")
logger.info(f" - 总上下文长度(估算): {total_memory_length + len(previous_content) + 2000} 字符")
# 发送开始事件
yield f"data: {json.dumps({'type': 'start', 'message': '开始AI创作...'}, ensure_ascii=False)}\n\n"
# 根据是否有前置内容选择不同的提示词,并应用写作风格
# 根据是否有前置内容选择不同的提示词,并应用写作风格和记忆增强
if previous_content:
prompt = prompt_service.get_chapter_generation_with_context_prompt(
title=project.title,
@@ -418,7 +777,8 @@ async def generate_chapter_content_stream(
chapter_title=current_chapter.title,
chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲',
style_content=style_content,
target_word_count=target_word_count
target_word_count=target_word_count,
memory_context=memory_context
)
else:
prompt = prompt_service.get_chapter_generation_prompt(
@@ -436,7 +796,8 @@ async def generate_chapter_content_stream(
chapter_title=current_chapter.title,
chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲',
style_content=style_content,
target_word_count=target_word_count
target_word_count=target_word_count,
memory_context=memory_context
)
logger.info(f"开始AI流式创作章节 {chapter_id}")
@@ -474,8 +835,48 @@ async def generate_chapter_content_stream(
logger.info(f"成功创作章节 {chapter_id},共 {new_word_count}")
# 发送完成事件
yield f"data: {json.dumps({'type': 'done', 'message': '创作完成', 'word_count': new_word_count}, ensure_ascii=False)}\n\n"
# 创建分析任务并启动后台分析
analysis_task = AnalysisTask(
chapter_id=chapter_id,
user_id=current_user_id,
project_id=project.id,
status='pending',
progress=0
)
db_session.add(analysis_task)
await db_session.commit()
# 不需要refresh,只需要获取ID
task_id = analysis_task.id
# 启动后台分析任务
background_tasks.add_task(
analyze_chapter_background,
chapter_id=chapter_id,
user_id=current_user_id,
project_id=project.id,
task_id=task_id,
ai_service=user_ai_service
)
logger.info(f"📋 已创建分析任务: {task_id}")
# 发送完成事件(包含分析任务ID
completion_data = {
'type': 'done',
'message': '创作完成',
'word_count': new_word_count,
'analysis_task_id': task_id
}
yield f"data: {json.dumps(completion_data, ensure_ascii=False)}\n\n"
# 发送分析排队事件
analysis_queued_data = {
'type': 'analysis_queued',
'task_id': task_id,
'message': '章节分析已加入队列'
}
yield f"data: {json.dumps(analysis_queued_data, ensure_ascii=False)}\n\n"
break # 退出async for db_session循环
@@ -527,3 +928,308 @@ async def generate_chapter_content_stream(
"X-Accel-Buffering": "no"
}
)
@router.get("/{chapter_id}/analysis/status", summary="查询章节分析任务状态")
async def get_analysis_task_status(
chapter_id: str,
db: AsyncSession = Depends(get_db)
):
"""
查询指定章节的最新分析任务状态
返回:
- task_id: 任务ID
- status: pending/running/completed/failed
- progress: 0-100
- error_message: 错误信息(如果失败)
- created_at: 创建时间
- completed_at: 完成时间
"""
# 获取该章节最新的分析任务
result = await db.execute(
select(AnalysisTask)
.where(AnalysisTask.chapter_id == chapter_id)
.order_by(AnalysisTask.created_at.desc())
.limit(1)
)
task = result.scalar_one_or_none()
if not task:
raise HTTPException(status_code=404, detail="未找到分析任务")
return {
"task_id": task.id,
"chapter_id": task.chapter_id,
"status": task.status,
"progress": task.progress,
"error_message": task.error_message,
"created_at": task.created_at.isoformat() if task.created_at else None,
"started_at": task.started_at.isoformat() if task.started_at else None,
"completed_at": task.completed_at.isoformat() if task.completed_at else None
}
@router.get("/{chapter_id}/analysis", summary="获取章节分析结果")
async def get_chapter_analysis(
chapter_id: str,
db: AsyncSession = Depends(get_db)
):
"""
获取章节的完整分析结果
返回:
- analysis_data: 完整的分析数据(JSON)
- summary: 分析摘要文本
- memories: 提取的记忆列表
- created_at: 分析时间
"""
# 获取分析结果
analysis_result = await db.execute(
select(PlotAnalysis)
.where(PlotAnalysis.chapter_id == chapter_id)
.order_by(PlotAnalysis.created_at.desc())
.limit(1)
)
analysis = analysis_result.scalar_one_or_none()
if not analysis:
raise HTTPException(status_code=404, detail="该章节暂无分析结果")
# 获取相关记忆
memories_result = await db.execute(
select(StoryMemory)
.where(StoryMemory.chapter_id == chapter_id)
.order_by(StoryMemory.importance_score.desc())
)
memories = memories_result.scalars().all()
return {
"chapter_id": chapter_id,
"analysis": analysis.to_dict(), # 使用to_dict()方法
"memories": [
{
"id": mem.id,
"type": mem.memory_type,
"title": mem.title,
"content": mem.content,
"importance": mem.importance_score,
"tags": mem.tags,
"is_foreshadow": mem.is_foreshadow,
"position": mem.chapter_position,
"related_characters": mem.related_characters
}
for mem in memories
],
"created_at": analysis.created_at.isoformat() if analysis.created_at else None
}
@router.get("/{chapter_id}/annotations", summary="获取章节标注数据")
async def get_chapter_annotations(
chapter_id: str,
db: AsyncSession = Depends(get_db)
):
"""
获取章节的标注数据(用于前端展示标注)
返回格式化的标注列表,包含精确位置信息
适用于章节内容的可视化标注展示
"""
# 获取章节
chapter_result = await db.execute(
select(Chapter).where(Chapter.id == chapter_id)
)
chapter = chapter_result.scalar_one_or_none()
if not chapter:
raise HTTPException(status_code=404, detail="章节不存在")
# 获取分析结果
analysis_result = await db.execute(
select(PlotAnalysis)
.where(PlotAnalysis.chapter_id == chapter_id)
.order_by(PlotAnalysis.created_at.desc())
.limit(1)
)
analysis = analysis_result.scalar_one_or_none()
# 获取记忆
memories_result = await db.execute(
select(StoryMemory)
.where(StoryMemory.chapter_id == chapter_id)
.order_by(StoryMemory.importance_score.desc())
)
memories = memories_result.scalars().all()
# 构建标注数据
annotations = []
for mem in memories:
# 优先从数据库读取位置信息
position = mem.chapter_position if mem.chapter_position is not None else -1
length = mem.text_length if hasattr(mem, 'text_length') and mem.text_length is not None else 0
metadata_extra = {}
# 如果数据库中没有位置信息,尝试从分析数据中重新计算
if position == -1 and analysis and chapter.content:
# 根据记忆类型从分析数据中查找对应项
if mem.memory_type == 'hook' and analysis.hooks:
for hook in analysis.hooks:
# 通过标题或内容匹配
if mem.title and hook.get('type') in mem.title:
keyword = hook.get('keyword', '')
if keyword:
pos = chapter.content.find(keyword)
if pos != -1:
position = pos
length = len(keyword)
metadata_extra["strength"] = hook.get('strength', 5)
metadata_extra["position_desc"] = hook.get('position', '')
break
elif mem.memory_type == 'foreshadow' and analysis.foreshadows:
for foreshadow in analysis.foreshadows:
if foreshadow.get('content') in mem.content:
keyword = foreshadow.get('keyword', '')
if keyword:
pos = chapter.content.find(keyword)
if pos != -1:
position = pos
length = len(keyword)
metadata_extra["foreshadow_type"] = foreshadow.get('type', 'planted')
metadata_extra["strength"] = foreshadow.get('strength', 5)
break
elif mem.memory_type == 'plot_point' and analysis.plot_points:
for plot_point in analysis.plot_points:
if plot_point.get('content') in mem.content:
keyword = plot_point.get('keyword', '')
if keyword:
pos = chapter.content.find(keyword)
if pos != -1:
position = pos
length = len(keyword)
break
else:
# 如果数据库有位置,也从分析数据中提取额外的元数据
if analysis:
if mem.memory_type == 'hook' and analysis.hooks:
for hook in analysis.hooks:
if mem.title and hook.get('type') in mem.title:
metadata_extra["strength"] = hook.get('strength', 5)
metadata_extra["position_desc"] = hook.get('position', '')
break
elif mem.memory_type == 'foreshadow' and analysis.foreshadows:
for foreshadow in analysis.foreshadows:
if foreshadow.get('content') in mem.content:
metadata_extra["foreshadow_type"] = foreshadow.get('type', 'planted')
metadata_extra["strength"] = foreshadow.get('strength', 5)
break
annotation = {
"id": mem.id,
"type": mem.memory_type,
"title": mem.title,
"content": mem.content,
"importance": mem.importance_score or 0.5,
"position": position,
"length": length,
"tags": mem.tags or [],
"metadata": {
"is_foreshadow": mem.is_foreshadow,
"related_characters": mem.related_characters or [],
"related_locations": mem.related_locations or [],
**metadata_extra
}
}
annotations.append(annotation)
return {
"chapter_id": chapter_id,
"chapter_number": chapter.chapter_number,
"title": chapter.title,
"word_count": chapter.word_count or 0,
"annotations": annotations,
"has_analysis": analysis is not None,
"summary": {
"total_annotations": len(annotations),
"hooks": len([a for a in annotations if a["type"] == "hook"]),
"foreshadows": len([a for a in annotations if a["type"] == "foreshadow"]),
"plot_points": len([a for a in annotations if a["type"] == "plot_point"]),
"character_events": len([a for a in annotations if a["type"] == "character_event"])
}
}
@router.post("/{chapter_id}/analyze", summary="手动触发章节分析")
async def trigger_chapter_analysis(
chapter_id: str,
request: Request,
background_tasks: BackgroundTasks,
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
"""
手动触发章节分析(用于重新分析或分析旧章节)
"""
# 从请求中获取用户ID
user_id = getattr(request.state, "user_id", None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
# 验证章节存在
chapter_result = await db.execute(
select(Chapter).where(Chapter.id == chapter_id)
)
chapter = chapter_result.scalar_one_or_none()
if not chapter:
raise HTTPException(status_code=404, detail="章节不存在")
if not chapter.content or chapter.content.strip() == "":
raise HTTPException(status_code=400, detail="章节内容为空,无法分析")
# 获取项目信息
project_result = await db.execute(
select(Project).where(Project.id == chapter.project_id)
)
project = project_result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 创建分析任务
analysis_task = AnalysisTask(
chapter_id=chapter_id,
user_id=user_id,
project_id=project.id,
status='pending',
progress=0
)
db.add(analysis_task)
await db.commit()
# 注意:不需要refresh,因为我们只需要id,而id在commit后已经生成
task_id = analysis_task.id
# 启动后台分析任务
background_tasks.add_task(
analyze_chapter_background,
chapter_id=chapter_id,
user_id=user_id,
project_id=project.id,
task_id=task_id,
ai_service=user_ai_service
)
logger.info(f"📋 手动触发分析任务: {task_id}")
return {
"task_id": task_id,
"chapter_id": chapter_id,
"status": "pending",
"message": "分析任务已创建并加入队列"
}
+383
View File
@@ -0,0 +1,383 @@
"""记忆管理API - 提供记忆的查询、分析等接口"""
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, desc
from typing import List, Optional
from app.database import get_db
from app.models.memory import StoryMemory, PlotAnalysis
from app.models.chapter import Chapter
from app.services.memory_service import memory_service
from app.services.plot_analyzer import get_plot_analyzer
from app.services.ai_service import create_user_ai_service
from app.models.settings import Settings
from app.logger import get_logger
import uuid
logger = get_logger(__name__)
router = APIRouter(prefix="/api/memories", tags=["memories"])
@router.post("/projects/{project_id}/analyze-chapter/{chapter_id}")
async def analyze_chapter(
project_id: str,
chapter_id: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
分析章节并生成记忆
对指定章节进行剧情分析,提取钩子、伏笔、情节点等,并存入记忆系统
"""
try:
user_id = request.state.user_id
# 获取章节内容
result = await db.execute(
select(Chapter).where(
and_(
Chapter.id == chapter_id,
Chapter.project_id == project_id
)
)
)
chapter = result.scalar_one_or_none()
if not chapter:
raise HTTPException(status_code=404, detail="章节不存在")
if not chapter.content:
raise HTTPException(status_code=400, detail="章节内容为空,无法分析")
# 获取用户AI设置
settings_result = await db.execute(select(Settings))
settings = settings_result.scalar_one_or_none()
if not settings:
raise HTTPException(status_code=400, detail="请先配置AI设置")
# 创建AI服务
ai_service = create_user_ai_service(
api_provider=settings.api_provider,
api_key=settings.api_key,
api_base_url=settings.api_base_url,
model_name=settings.model_name,
temperature=settings.temperature,
max_tokens=settings.max_tokens
)
# 执行剧情分析
analyzer = get_plot_analyzer(ai_service)
analysis_result = await analyzer.analyze_chapter(
chapter_number=chapter.chapter_number,
title=chapter.title,
content=chapter.content,
word_count=chapter.word_count or len(chapter.content)
)
if not analysis_result:
raise HTTPException(status_code=500, detail="剧情分析失败")
# 保存分析结果到数据库
plot_analysis = PlotAnalysis(
id=str(uuid.uuid4()),
project_id=project_id,
chapter_id=chapter_id,
plot_stage=analysis_result.get('plot_stage'),
conflict_level=analysis_result.get('conflict', {}).get('level'),
conflict_types=analysis_result.get('conflict', {}).get('types'),
emotional_tone=analysis_result.get('emotional_arc', {}).get('primary_emotion'),
emotional_intensity=analysis_result.get('emotional_arc', {}).get('intensity', 0) / 10,
emotional_curve=analysis_result.get('emotional_arc'),
hooks=analysis_result.get('hooks'),
hooks_count=len(analysis_result.get('hooks', [])),
hooks_avg_strength=sum(h.get('strength', 0) for h in analysis_result.get('hooks', [])) / max(len(analysis_result.get('hooks', [])), 1),
foreshadows=analysis_result.get('foreshadows'),
foreshadows_planted=sum(1 for f in analysis_result.get('foreshadows', []) if f.get('type') == 'planted'),
foreshadows_resolved=sum(1 for f in analysis_result.get('foreshadows', []) if f.get('type') == 'resolved'),
plot_points=analysis_result.get('plot_points'),
plot_points_count=len(analysis_result.get('plot_points', [])),
character_states=analysis_result.get('character_states'),
scenes=analysis_result.get('scenes'),
pacing=analysis_result.get('pacing'),
dialogue_ratio=analysis_result.get('dialogue_ratio'),
description_ratio=analysis_result.get('description_ratio'),
overall_quality_score=analysis_result.get('scores', {}).get('overall'),
pacing_score=analysis_result.get('scores', {}).get('pacing'),
engagement_score=analysis_result.get('scores', {}).get('engagement'),
coherence_score=analysis_result.get('scores', {}).get('coherence'),
analysis_report=analyzer.generate_analysis_summary(analysis_result),
suggestions=analysis_result.get('suggestions'),
word_count=chapter.word_count
)
# 检查是否已存在分析记录
existing = await db.execute(
select(PlotAnalysis).where(PlotAnalysis.chapter_id == chapter_id)
)
if existing.scalar_one_or_none():
# 删除旧记录
await db.execute(
select(PlotAnalysis).where(PlotAnalysis.chapter_id == chapter_id)
)
await db.delete(existing.scalar_one())
db.add(plot_analysis)
await db.commit()
# 从分析结果中提取记忆片段
memories_data = analyzer.extract_memories_from_analysis(
analysis_result,
chapter_id,
chapter.chapter_number
)
# 保存记忆到数据库和向量库
saved_count = 0
for mem_data in memories_data:
memory_id = str(uuid.uuid4())
# 保存到关系数据库
memory = StoryMemory(
id=memory_id,
project_id=project_id,
chapter_id=chapter_id,
memory_type=mem_data['type'],
title=mem_data.get('title', ''),
content=mem_data['content'],
story_timeline=chapter.chapter_number,
vector_id=memory_id,
**mem_data['metadata']
)
db.add(memory)
# 保存到向量库
await memory_service.add_memory(
user_id=user_id,
project_id=project_id,
memory_id=memory_id,
content=mem_data['content'],
memory_type=mem_data['type'],
metadata=mem_data['metadata']
)
saved_count += 1
await db.commit()
logger.info(f"✅ 章节分析完成: 保存{saved_count}条记忆")
return {
"success": True,
"message": f"分析完成,提取了{saved_count}条记忆",
"analysis": plot_analysis.to_dict(),
"memories_count": saved_count
}
except HTTPException:
raise
except Exception as e:
logger.error(f"❌ 章节分析失败: {str(e)}")
await db.rollback()
raise HTTPException(status_code=500, detail=f"分析失败: {str(e)}")
@router.get("/projects/{project_id}/memories")
async def get_project_memories(
project_id: str,
request: Request,
memory_type: Optional[str] = None,
chapter_id: Optional[str] = None,
limit: int = 50,
db: AsyncSession = Depends(get_db)
):
"""获取项目的记忆列表"""
try:
user_id = request.state.user_id
# 构建查询
query = select(StoryMemory).where(StoryMemory.project_id == project_id)
if memory_type:
query = query.where(StoryMemory.memory_type == memory_type)
if chapter_id:
query = query.where(StoryMemory.chapter_id == chapter_id)
query = query.order_by(desc(StoryMemory.importance_score), desc(StoryMemory.created_at)).limit(limit)
result = await db.execute(query)
memories = result.scalars().all()
return {
"success": True,
"memories": [mem.to_dict() for mem in memories],
"total": len(memories)
}
except Exception as e:
logger.error(f"❌ 获取记忆失败: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/projects/{project_id}/analysis/{chapter_id}")
async def get_chapter_analysis(
project_id: str,
chapter_id: str,
db: AsyncSession = Depends(get_db)
):
"""获取章节的剧情分析"""
try:
result = await db.execute(
select(PlotAnalysis).where(
and_(
PlotAnalysis.project_id == project_id,
PlotAnalysis.chapter_id == chapter_id
)
)
)
analysis = result.scalar_one_or_none()
if not analysis:
raise HTTPException(status_code=404, detail="该章节还未进行分析")
return {
"success": True,
"analysis": analysis.to_dict()
}
except HTTPException:
raise
except Exception as e:
logger.error(f"❌ 获取分析失败: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/projects/{project_id}/search")
async def search_memories(
project_id: str,
request: Request,
query: str,
memory_types: Optional[List[str]] = None,
limit: int = 10,
min_importance: float = 0.0
):
"""语义搜索项目记忆"""
try:
user_id = request.state.user_id
memories = await memory_service.search_memories(
user_id=user_id,
project_id=project_id,
query=query,
memory_types=memory_types,
limit=limit,
min_importance=min_importance
)
return {
"success": True,
"query": query,
"memories": memories,
"total": len(memories)
}
except Exception as e:
logger.error(f"❌ 搜索记忆失败: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/projects/{project_id}/foreshadows")
async def get_unresolved_foreshadows(
project_id: str,
request: Request,
current_chapter: int,
db: AsyncSession = Depends(get_db)
):
"""获取未完结的伏笔"""
try:
user_id = request.state.user_id
# 从向量库搜索
foreshadows = await memory_service.find_unresolved_foreshadows(
user_id=user_id,
project_id=project_id,
current_chapter=current_chapter
)
return {
"success": True,
"foreshadows": foreshadows,
"total": len(foreshadows)
}
except Exception as e:
logger.error(f"❌ 获取伏笔失败: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/projects/{project_id}/stats")
async def get_memory_stats(
project_id: str,
request: Request
):
"""获取记忆统计信息"""
try:
user_id = request.state.user_id
stats = await memory_service.get_memory_stats(
user_id=user_id,
project_id=project_id
)
return {
"success": True,
"stats": stats
}
except Exception as e:
logger.error(f"❌ 获取统计失败: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/projects/{project_id}/chapters/{chapter_id}/memories")
async def delete_chapter_memories(
project_id: str,
chapter_id: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""删除章节的所有记忆"""
try:
user_id = request.state.user_id
# 从数据库删除
result = await db.execute(
select(StoryMemory).where(
and_(
StoryMemory.project_id == project_id,
StoryMemory.chapter_id == chapter_id
)
)
)
memories = result.scalars().all()
for memory in memories:
await db.delete(memory)
# 从向量库删除
await memory_service.delete_chapter_memories(
user_id=user_id,
project_id=project_id,
chapter_id=chapter_id
)
await db.commit()
return {
"success": True,
"message": f"已删除{len(memories)}条记忆"
}
except Exception as e:
logger.error(f"❌ 删除记忆失败: {str(e)}")
await db.rollback()
raise HTTPException(status_code=500, detail=str(e))
+62 -12
View File
@@ -1,5 +1,5 @@
"""大纲管理API"""
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, delete
from typing import List, AsyncGenerator, Dict, Any
@@ -21,6 +21,7 @@ from app.schemas.outline import (
)
from app.services.ai_service import AIService
from app.services.prompt_service import prompt_service
from app.services.memory_service import memory_service
from app.logger import get_logger
from app.api.settings import get_user_ai_service
from app.utils.sse_response import SSEResponse, create_sse_response
@@ -328,6 +329,7 @@ async def reorder_outlines(
@router.post("/generate", response_model=OutlineListResponse, summary="AI生成/续写大纲")
async def generate_outline(
request: OutlineGenerateRequest,
http_request: Request,
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
@@ -377,8 +379,10 @@ async def generate_outline(
detail="续写模式需要已有大纲,当前项目没有大纲"
)
# 获取用户ID用于记忆检索
user_id = getattr(http_request.state, "user_id", "system")
return await _continue_outline(
request, project, existing_outlines, db, user_ai_service
request, project, existing_outlines, db, user_ai_service, user_id
)
else:
@@ -478,9 +482,10 @@ async def _continue_outline(
project: Project,
existing_outlines: List[Outline],
db: AsyncSession,
user_ai_service: AIService
user_ai_service: AIService,
user_id: str = "system"
) -> OutlineListResponse:
"""续写大纲 - 分批生成,每批5章"""
"""续写大纲 - 分批生成,每批5章(记忆增强版)"""
logger.info(f"续写大纲 - 项目: {project.id}, 已有: {len(existing_outlines)}")
# 分析已有大纲
@@ -545,7 +550,25 @@ async def _continue_outline(
for o in latest_outlines
])
# 使用标准续写提示词模板
# 🧠 构建记忆增强上下文(仅续写模式需要)
memory_context = None
try:
logger.info(f"🧠 为第{batch_num + 1}批构建记忆上下文...")
# 使用最近一章的大纲作为查询
query_outline = recent_outlines[-1].content if recent_outlines else ""
memory_context = await memory_service.build_context_for_generation(
user_id=user_id,
project_id=project.id,
current_chapter=current_start_chapter,
chapter_outline=query_outline,
character_names=[c.name for c in characters] if characters else None
)
logger.info(f"✅ 记忆上下文构建完成: {memory_context['stats']}")
except Exception as e:
logger.warning(f"⚠️ 记忆上下文构建失败,继续不使用记忆: {str(e)}")
memory_context = None
# 使用标准续写提示词模板(支持记忆增强)
prompt = prompt_service.get_outline_continue_prompt(
title=project.title,
theme=request.theme or project.theme or "未设定",
@@ -563,7 +586,8 @@ async def _continue_outline(
plot_stage_instruction=stage_instruction,
start_chapter=current_start_chapter,
story_direction=request.story_direction or "自然延续",
requirements=request.requirements or ""
requirements=request.requirements or "",
memory_context=memory_context
)
# 调用AI生成当前批次
@@ -834,9 +858,10 @@ async def new_outline_generator(
async def continue_outline_generator(
data: Dict[str, Any],
db: AsyncSession,
user_ai_service: AIService
user_ai_service: AIService,
user_id: str = "system"
) -> AsyncGenerator[str, None]:
"""大纲续写SSE生成器 - 分批生成,推送进度"""
"""大纲续写SSE生成器 - 分批生成,推送进度(记忆增强版)"""
db_committed = False
try:
yield await SSEResponse.send_progress("开始续写大纲...", 5)
@@ -940,12 +965,32 @@ async def continue_outline_generator(
for o in latest_outlines
])
# 🧠 构建记忆增强上下文
memory_context = None
try:
yield await SSEResponse.send_progress(
f"🧠 构建记忆上下文...",
batch_progress + 3
)
query_outline = recent_outlines[-1].content if recent_outlines else ""
memory_context = await memory_service.build_context_for_generation(
user_id=user_id,
project_id=project_id,
current_chapter=current_start_chapter,
chapter_outline=query_outline,
character_names=[c.name for c in characters] if characters else None
)
logger.info(f"✅ 记忆上下文: {memory_context['stats']}")
except Exception as e:
logger.warning(f"⚠️ 记忆上下文构建失败: {str(e)}")
memory_context = None
yield await SSEResponse.send_progress(
f"🤖 调用AI生成第{str(batch_num + 1)}批...",
f" 调用AI生成第{str(batch_num + 1)}批...",
batch_progress + 5
)
# 使用标准续写提示词模板
# 使用标准续写提示词模板(支持记忆增强)
prompt = prompt_service.get_outline_continue_prompt(
title=project.title,
theme=data.get("theme") or project.theme or "未设定",
@@ -963,7 +1008,8 @@ async def continue_outline_generator(
plot_stage_instruction=stage_instruction,
start_chapter=current_start_chapter,
story_direction=data.get("story_direction", "自然延续"),
requirements=data.get("requirements", "")
requirements=data.get("requirements", ""),
memory_context=memory_context
)
# 调用AI生成当前批次
@@ -1062,6 +1108,7 @@ async def continue_outline_generator(
@router.post("/generate-stream", summary="AI生成/续写大纲(SSE流式)")
async def generate_outline_stream(
data: Dict[str, Any],
request: Request,
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
@@ -1111,6 +1158,9 @@ async def generate_outline_stream(
mode = "continue" if existing_outlines else "new"
logger.info(f"自动判断模式:{'续写' if existing_outlines else '新建'}")
# 获取用户ID
user_id = getattr(request.state, "user_id", "system")
# 根据模式选择生成器
if mode == "new":
return create_sse_response(new_outline_generator(data, db, user_ai_service))
@@ -1120,7 +1170,7 @@ async def generate_outline_stream(
status_code=400,
detail="续写模式需要已有大纲,当前项目没有大纲"
)
return create_sse_response(continue_outline_generator(data, db, user_ai_service))
return create_sse_response(continue_outline_generator(data, db, user_ai_service, user_id))
else:
raise HTTPException(
status_code=400,
+174 -2
View File
@@ -1,9 +1,11 @@
"""项目管理API"""
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
from fastapi.responses import Response
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, delete
from typing import List
import json
from urllib.parse import quote
from app.database import get_db
from app.models.project import Project
from app.models.character import Character
@@ -17,6 +19,12 @@ from app.schemas.project import (
ProjectResponse,
ProjectListResponse
)
from app.schemas.import_export import (
ExportOptions,
ImportValidationResult,
ImportResult
)
from app.services.import_export_service import ImportExportService
from app.logger import get_logger
from app.utils.data_consistency import (
run_full_data_consistency_check,
@@ -412,4 +420,168 @@ async def fix_project_member_counts(
raise
except Exception as e:
logger.error(f"修复成员计数失败: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"修复失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"修复失败: {str(e)}")
@router.post("/{project_id}/export-data", summary="导出项目数据为JSON")
async def export_project_data(
project_id: str,
options: ExportOptions,
db: AsyncSession = Depends(get_db)
):
"""
导出项目完整数据为JSON格式
Args:
project_id: 项目ID
options: 导出选项
Returns:
JSON文件下载
"""
try:
logger.info(f"开始导出项目数据: {project_id}")
# 检查项目是否存在
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目不存在: {project_id}")
raise HTTPException(status_code=404, detail="项目不存在")
# 导出数据
export_data = await ImportExportService.export_project(
project_id=project_id,
db=db,
include_generation_history=options.include_generation_history,
include_writing_styles=options.include_writing_styles
)
# 转换为JSON
json_content = export_data.model_dump_json(indent=2, exclude_none=True, by_alias=True)
# 生成文件名
safe_title = "".join(c for c in project.title if c.isalnum() or c in (' ', '-', '_'))
from datetime import datetime
date_str = datetime.now().strftime("%Y%m%d")
filename = f"project_{safe_title}_{date_str}.json"
encoded_filename = quote(filename)
logger.info(f"项目数据导出成功: {filename}")
return Response(
content=json_content.encode('utf-8'),
media_type="application/json; charset=utf-8",
headers={
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
"Content-Type": "application/json; charset=utf-8"
}
)
except HTTPException:
raise
except Exception as e:
logger.error(f"导出项目数据失败: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"导出失败: {str(e)}")
@router.post("/validate-import", response_model=ImportValidationResult, summary="验证导入文件")
async def validate_import_file(
file: UploadFile = File(...)
):
"""
验证导入文件的格式和内容
Args:
file: 上传的JSON文件
Returns:
验证结果
"""
try:
logger.info(f"验证导入文件: {file.filename}")
# 检查文件类型
if not file.filename.endswith('.json'):
raise HTTPException(status_code=400, detail="只支持JSON格式文件")
# 读取文件内容
content = await file.read()
# 检查文件大小(50MB限制)
max_size = 50 * 1024 * 1024 # 50MB
if len(content) > max_size:
raise HTTPException(status_code=413, detail="文件大小超过50MB限制")
# 解析JSON
try:
data = json.loads(content.decode('utf-8'))
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"无效的JSON格式: {str(e)}")
# 验证数据
validation_result = ImportExportService.validate_import_data(data)
logger.info(f"文件验证完成: valid={validation_result.valid}")
return validation_result
except HTTPException:
raise
except Exception as e:
logger.error(f"验证导入文件失败: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"验证失败: {str(e)}")
@router.post("/import", response_model=ImportResult, summary="导入项目")
async def import_project(
file: UploadFile = File(...),
db: AsyncSession = Depends(get_db)
):
"""
导入项目数据(创建新项目)
Args:
file: 上传的JSON文件
Returns:
导入结果
"""
try:
logger.info(f"开始导入项目: {file.filename}")
# 检查文件类型
if not file.filename.endswith('.json'):
raise HTTPException(status_code=400, detail="只支持JSON格式文件")
# 读取文件内容
content = await file.read()
# 检查文件大小
max_size = 50 * 1024 * 1024 # 50MB
if len(content) > max_size:
raise HTTPException(status_code=413, detail="文件大小超过50MB限制")
# 解析JSON
try:
data = json.loads(content.decode('utf-8'))
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"无效的JSON格式: {str(e)}")
# 导入数据
import_result = await ImportExportService.import_project(data, db)
if import_result.success:
logger.info(f"项目导入成功: {import_result.project_id}")
else:
logger.warning(f"项目导入失败: {import_result.message}")
return import_result
except HTTPException:
raise
except Exception as e:
logger.error(f"导入项目失败: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"导入失败: {str(e)}")