update:1.更新导入导出功能 2.实现RAG记忆功能,引入剧情分析功能

This commit is contained in:
xiamuceer
2025-11-04 14:38:59 +08:00
parent 1cde345ed9
commit e4f90d5da0
26 changed files with 6722 additions and 84 deletions
+712 -6
View File
@@ -1,11 +1,12 @@
"""章节管理API"""
from fastapi import APIRouter, Depends, HTTPException, Request, Query
from fastapi import APIRouter, Depends, HTTPException, Request, Query, BackgroundTasks
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
import json
import asyncio
from typing import Optional
from datetime import datetime
from app.database import get_db
from app.models.chapter import Chapter
@@ -14,6 +15,8 @@ from app.models.outline import Outline
from app.models.character import Character
from app.models.generation_history import GenerationHistory
from app.models.writing_style import WritingStyle
from app.models.analysis_task import AnalysisTask
from app.models.memory import PlotAnalysis, StoryMemory
from app.schemas.chapter import (
ChapterCreate,
ChapterUpdate,
@@ -23,6 +26,8 @@ from app.schemas.chapter import (
)
from app.services.ai_service import AIService
from app.services.prompt_service import prompt_service
from app.services.plot_analyzer import PlotAnalyzer
from app.services.memory_service import memory_service
from app.logger import get_logger
from app.api.settings import get_user_ai_service
@@ -101,6 +106,63 @@ async def get_chapter(
return chapter
@router.get("/{chapter_id}/navigation", summary="获取章节导航信息")
async def get_chapter_navigation(
chapter_id: str,
db: AsyncSession = Depends(get_db)
):
"""
获取章节的导航信息(上一章/下一章)
用于章节阅读器的翻页功能
"""
# 获取当前章节
result = await db.execute(
select(Chapter).where(Chapter.id == chapter_id)
)
current_chapter = result.scalar_one_or_none()
if not current_chapter:
raise HTTPException(status_code=404, detail="章节不存在")
# 获取上一章
prev_result = await db.execute(
select(Chapter)
.where(Chapter.project_id == current_chapter.project_id)
.where(Chapter.chapter_number < current_chapter.chapter_number)
.order_by(Chapter.chapter_number.desc())
.limit(1)
)
prev_chapter = prev_result.scalar_one_or_none()
# 获取下一章
next_result = await db.execute(
select(Chapter)
.where(Chapter.project_id == current_chapter.project_id)
.where(Chapter.chapter_number > current_chapter.chapter_number)
.order_by(Chapter.chapter_number.asc())
.limit(1)
)
next_chapter = next_result.scalar_one_or_none()
return {
"current": {
"id": current_chapter.id,
"chapter_number": current_chapter.chapter_number,
"title": current_chapter.title
},
"previous": {
"id": prev_chapter.id,
"chapter_number": prev_chapter.chapter_number,
"title": prev_chapter.title
} if prev_chapter else None,
"next": {
"id": next_chapter.id,
"chapter_number": next_chapter.chapter_number,
"title": next_chapter.title
} if next_chapter else None
}
@router.put("/{chapter_id}", response_model=ChapterResponse, summary="更新章节")
async def update_chapter(
chapter_id: str,
@@ -248,10 +310,273 @@ async def check_can_generate(
}
async def analyze_chapter_background(
chapter_id: str,
user_id: str,
project_id: str,
task_id: str,
ai_service: AIService
):
"""
后台异步分析章节
Args:
chapter_id: 章节ID
user_id: 用户ID
project_id: 项目ID
task_id: 任务ID
ai_service: AI服务实例
"""
db_session = None
try:
logger.info(f"🔍 开始后台分析章节: {chapter_id}")
# 等待一小段时间,确保主会话的commit已经持久化到磁盘
await asyncio.sleep(0.1)
# 创建独立数据库会话
from app.database import get_engine
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession
engine = await get_engine(user_id)
AsyncSessionLocal = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False
)
db_session = AsyncSessionLocal()
# 1. 获取任务(添加重试逻辑)
task = None
for retry in range(3):
task_result = await db_session.execute(
select(AnalysisTask).where(AnalysisTask.id == task_id)
)
task = task_result.scalar_one_or_none()
if task:
break
if retry < 2:
logger.warning(f"⚠️ 第{retry+1}次未找到任务 {task_id},等待后重试...")
await asyncio.sleep(0.2)
if not task:
logger.error(f"❌ 任务不存在: {task_id}")
return
task.status = 'running'
task.started_at = datetime.now()
task.progress = 10
await db_session.commit()
# 2. 获取章节信息
chapter_result = await db_session.execute(
select(Chapter).where(Chapter.id == chapter_id)
)
chapter = chapter_result.scalar_one_or_none()
if not chapter or not chapter.content:
task.status = 'failed'
task.error_message = '章节不存在或内容为空'
task.completed_at = datetime.now()
await db_session.commit()
logger.error(f"❌ 章节不存在或内容为空: {chapter_id}")
return
task.progress = 20
await db_session.commit()
# 3. 使用PlotAnalyzer分析章节
analyzer = PlotAnalyzer(ai_service)
analysis_result = await analyzer.analyze_chapter(
chapter_number=chapter.chapter_number,
title=chapter.title,
content=chapter.content,
word_count=chapter.word_count or len(chapter.content)
)
if not analysis_result:
task.status = 'failed'
task.error_message = 'AI分析失败,请检查日志'
task.completed_at = datetime.now()
await db_session.commit()
logger.error(f"❌ AI分析失败: {chapter_id}")
return
task.progress = 60
await db_session.commit()
# 4. 保存分析结果到数据库(先检查是否已存在)
existing_analysis_result = await db_session.execute(
select(PlotAnalysis).where(PlotAnalysis.chapter_id == chapter_id)
)
existing_analysis = existing_analysis_result.scalar_one_or_none()
if existing_analysis:
# 更新现有记录
logger.info(f" 更新现有分析记录: {existing_analysis.id}")
existing_analysis.plot_stage = analysis_result.get('plot_stage', '发展')
existing_analysis.conflict_level = analysis_result.get('conflict', {}).get('level', 0)
existing_analysis.conflict_types = analysis_result.get('conflict', {}).get('types', [])
existing_analysis.emotional_tone = analysis_result.get('emotional_arc', {}).get('primary_emotion', '')
existing_analysis.emotional_intensity = analysis_result.get('emotional_arc', {}).get('intensity', 0) / 10.0
existing_analysis.hooks = analysis_result.get('hooks', [])
existing_analysis.hooks_count = len(analysis_result.get('hooks', []))
existing_analysis.foreshadows = analysis_result.get('foreshadows', [])
existing_analysis.foreshadows_planted = sum(1 for f in analysis_result.get('foreshadows', []) if f.get('type') == 'planted')
existing_analysis.foreshadows_resolved = sum(1 for f in analysis_result.get('foreshadows', []) if f.get('type') == 'resolved')
existing_analysis.plot_points = analysis_result.get('plot_points', [])
existing_analysis.plot_points_count = len(analysis_result.get('plot_points', []))
existing_analysis.character_states = analysis_result.get('character_states', [])
existing_analysis.scenes = analysis_result.get('scenes', [])
existing_analysis.pacing = analysis_result.get('pacing', 'moderate')
existing_analysis.overall_quality_score = analysis_result.get('scores', {}).get('overall', 0)
existing_analysis.pacing_score = analysis_result.get('scores', {}).get('pacing', 0)
existing_analysis.engagement_score = analysis_result.get('scores', {}).get('engagement', 0)
existing_analysis.coherence_score = analysis_result.get('scores', {}).get('coherence', 0)
existing_analysis.analysis_report = analyzer.generate_analysis_summary(analysis_result)
existing_analysis.suggestions = analysis_result.get('suggestions', [])
existing_analysis.dialogue_ratio = analysis_result.get('dialogue_ratio', 0)
existing_analysis.description_ratio = analysis_result.get('description_ratio', 0)
else:
# 创建新记录
logger.info(f" 创建新的分析记录")
plot_analysis = PlotAnalysis(
chapter_id=chapter_id,
project_id=project_id,
plot_stage=analysis_result.get('plot_stage', '发展'),
conflict_level=analysis_result.get('conflict', {}).get('level', 0),
conflict_types=analysis_result.get('conflict', {}).get('types', []),
emotional_tone=analysis_result.get('emotional_arc', {}).get('primary_emotion', ''),
emotional_intensity=analysis_result.get('emotional_arc', {}).get('intensity', 0) / 10.0,
hooks=analysis_result.get('hooks', []),
hooks_count=len(analysis_result.get('hooks', [])),
foreshadows=analysis_result.get('foreshadows', []),
foreshadows_planted=sum(1 for f in analysis_result.get('foreshadows', []) if f.get('type') == 'planted'),
foreshadows_resolved=sum(1 for f in analysis_result.get('foreshadows', []) if f.get('type') == 'resolved'),
plot_points=analysis_result.get('plot_points', []),
plot_points_count=len(analysis_result.get('plot_points', [])),
character_states=analysis_result.get('character_states', []),
scenes=analysis_result.get('scenes', []),
pacing=analysis_result.get('pacing', 'moderate'),
overall_quality_score=analysis_result.get('scores', {}).get('overall', 0),
pacing_score=analysis_result.get('scores', {}).get('pacing', 0),
engagement_score=analysis_result.get('scores', {}).get('engagement', 0),
coherence_score=analysis_result.get('scores', {}).get('coherence', 0),
analysis_report=analyzer.generate_analysis_summary(analysis_result),
suggestions=analysis_result.get('suggestions', []),
dialogue_ratio=analysis_result.get('dialogue_ratio', 0),
description_ratio=analysis_result.get('description_ratio', 0)
)
db_session.add(plot_analysis)
await db_session.commit()
task.progress = 80
await db_session.commit()
# 5. 提取记忆并保存到向量数据库(传入章节内容用于计算位置)
memories = analyzer.extract_memories_from_analysis(
analysis=analysis_result,
chapter_id=chapter_id,
chapter_number=chapter.chapter_number,
chapter_content=chapter.content or ""
)
# 先删除该章节的旧记忆(支持重新分析)
old_memories_result = await db_session.execute(
select(StoryMemory).where(StoryMemory.chapter_id == chapter_id)
)
old_memories = old_memories_result.scalars().all()
for old_mem in old_memories:
await db_session.delete(old_mem)
logger.info(f" 删除旧记忆: {len(old_memories)}")
# 准备批量添加的记忆数据
memory_records = []
for mem in memories:
memory_id = f"{chapter_id}_{mem['type']}_{len(memory_records)}"
memory_records.append({
'id': memory_id,
'content': mem['content'],
'type': mem['type'],
'metadata': mem['metadata']
})
# 从metadata中提取位置信息
text_position = mem['metadata'].get('text_position', -1)
text_length = mem['metadata'].get('text_length', 0)
# 同时保存到关系数据库
story_memory = StoryMemory(
id=memory_id,
project_id=project_id,
chapter_id=chapter_id,
memory_type=mem['type'],
content=mem['content'],
title=mem['title'],
importance_score=mem['metadata'].get('importance_score', 0.5),
tags=mem['metadata'].get('tags', []),
is_foreshadow=mem['metadata'].get('is_foreshadow', 0),
story_timeline=chapter.chapter_number, # 使用章节序号作为时间线
chapter_position=text_position, # 保存文本位置
text_length=text_length, # 保存文本长度
related_characters=mem['metadata'].get('related_characters', []),
related_locations=mem['metadata'].get('related_locations', [])
)
db_session.add(story_memory)
# 记录日志便于调试
if text_position >= 0:
logger.debug(f" 保存记忆 {memory_id}: position={text_position}, length={text_length}")
await db_session.commit()
# 批量添加到向量数据库
if memory_records:
added_count = await memory_service.batch_add_memories(
user_id=user_id,
project_id=project_id,
memories=memory_records
)
logger.info(f"✅ 添加{added_count}条记忆到向量库")
task.progress = 100
task.status = 'completed'
task.completed_at = datetime.now()
await db_session.commit()
logger.info(f"✅ 章节分析完成: {chapter_id}, 提取{len(memories)}条记忆")
except Exception as e:
logger.error(f"❌ 后台分析异常: {str(e)}", exc_info=True)
# 确保任务状态被更新为failed,避免前端一直轮询
if db_session:
try:
# 重新获取任务以确保有最新状态
task_result = await db_session.execute(
select(AnalysisTask).where(AnalysisTask.id == task_id)
)
task = task_result.scalar_one_or_none()
if task:
task.status = 'failed'
task.error_message = str(e)[:500]
task.completed_at = datetime.now()
task.progress = 0 # 重置进度
await db_session.commit()
logger.info(f"✅ 任务状态已更新为failed: {task_id}")
else:
logger.error(f"❌ 无法找到任务进行状态更新: {task_id}")
except Exception as update_error:
logger.error(f"❌ 更新任务状态失败: {str(update_error)}")
finally:
if db_session:
await db_session.close()
@router.post("/{chapter_id}/generate-stream", summary="AI创作章节内容(流式)")
async def generate_chapter_content_stream(
chapter_id: str,
request: Request,
background_tasks: BackgroundTasks,
generate_request: ChapterGenerateRequest = ChapterGenerateRequest(),
user_ai_service: AIService = Depends(get_user_ai_service)
):
@@ -301,6 +626,9 @@ async def generate_chapter_content_stream(
# 在生成器内部创建独立的数据库会话
db_session = None
db_committed = False
# 获取当前用户ID(在生成器外部就需要)
current_user_id = getattr(request.state, "user_id", "system")
try:
# 创建新的数据库会话
async for db_session in get_db(request):
@@ -396,11 +724,42 @@ async def generate_chapter_content_stream(
previous_content += recent_content
logger.info(f"构建前置上下文:{len(early_chapters)}章摘要 + {len(recent_chapters)}章完整内容")
# 🧠 构建记忆增强上下文
logger.info(f"🧠 开始构建记忆增强上下文...")
memory_context = await memory_service.build_context_for_generation(
user_id=current_user_id,
project_id=project.id,
current_chapter=current_chapter.chapter_number,
chapter_outline=outline.content if outline else current_chapter.summary or "",
character_names=[c.name for c in characters] if characters else None
)
# 计算各部分的字符长度
context_lengths = {
'recent_context': len(memory_context.get('recent_context', '')),
'relevant_memories': len(memory_context.get('relevant_memories', '')),
'foreshadows': len(memory_context.get('foreshadows', '')),
'character_states': len(memory_context.get('character_states', '')),
'plot_points': len(memory_context.get('plot_points', ''))
}
total_memory_length = sum(context_lengths.values())
logger.info(f"✅ 记忆上下文构建完成: {memory_context['stats']}")
logger.info(f"📏 记忆上下文长度统计:")
logger.info(f" - 最近章节记忆: {context_lengths['recent_context']} 字符")
logger.info(f" - 语义相关记忆: {context_lengths['relevant_memories']} 字符")
logger.info(f" - 未完结伏笔: {context_lengths['foreshadows']} 字符")
logger.info(f" - 角色状态记忆: {context_lengths['character_states']} 字符")
logger.info(f" - 重要情节点: {context_lengths['plot_points']} 字符")
logger.info(f" - 记忆总长度: {total_memory_length} 字符")
logger.info(f" - 前置章节上下文长度: {len(previous_content)} 字符")
logger.info(f" - 总上下文长度(估算): {total_memory_length + len(previous_content) + 2000} 字符")
# 发送开始事件
yield f"data: {json.dumps({'type': 'start', 'message': '开始AI创作...'}, ensure_ascii=False)}\n\n"
# 根据是否有前置内容选择不同的提示词,并应用写作风格
# 根据是否有前置内容选择不同的提示词,并应用写作风格和记忆增强
if previous_content:
prompt = prompt_service.get_chapter_generation_with_context_prompt(
title=project.title,
@@ -418,7 +777,8 @@ async def generate_chapter_content_stream(
chapter_title=current_chapter.title,
chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲',
style_content=style_content,
target_word_count=target_word_count
target_word_count=target_word_count,
memory_context=memory_context
)
else:
prompt = prompt_service.get_chapter_generation_prompt(
@@ -436,7 +796,8 @@ async def generate_chapter_content_stream(
chapter_title=current_chapter.title,
chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲',
style_content=style_content,
target_word_count=target_word_count
target_word_count=target_word_count,
memory_context=memory_context
)
logger.info(f"开始AI流式创作章节 {chapter_id}")
@@ -474,8 +835,48 @@ async def generate_chapter_content_stream(
logger.info(f"成功创作章节 {chapter_id},共 {new_word_count}")
# 发送完成事件
yield f"data: {json.dumps({'type': 'done', 'message': '创作完成', 'word_count': new_word_count}, ensure_ascii=False)}\n\n"
# 创建分析任务并启动后台分析
analysis_task = AnalysisTask(
chapter_id=chapter_id,
user_id=current_user_id,
project_id=project.id,
status='pending',
progress=0
)
db_session.add(analysis_task)
await db_session.commit()
# 不需要refresh,只需要获取ID
task_id = analysis_task.id
# 启动后台分析任务
background_tasks.add_task(
analyze_chapter_background,
chapter_id=chapter_id,
user_id=current_user_id,
project_id=project.id,
task_id=task_id,
ai_service=user_ai_service
)
logger.info(f"📋 已创建分析任务: {task_id}")
# 发送完成事件(包含分析任务ID
completion_data = {
'type': 'done',
'message': '创作完成',
'word_count': new_word_count,
'analysis_task_id': task_id
}
yield f"data: {json.dumps(completion_data, ensure_ascii=False)}\n\n"
# 发送分析排队事件
analysis_queued_data = {
'type': 'analysis_queued',
'task_id': task_id,
'message': '章节分析已加入队列'
}
yield f"data: {json.dumps(analysis_queued_data, ensure_ascii=False)}\n\n"
break # 退出async for db_session循环
@@ -527,3 +928,308 @@ async def generate_chapter_content_stream(
"X-Accel-Buffering": "no"
}
)
@router.get("/{chapter_id}/analysis/status", summary="查询章节分析任务状态")
async def get_analysis_task_status(
chapter_id: str,
db: AsyncSession = Depends(get_db)
):
"""
查询指定章节的最新分析任务状态
返回:
- task_id: 任务ID
- status: pending/running/completed/failed
- progress: 0-100
- error_message: 错误信息(如果失败)
- created_at: 创建时间
- completed_at: 完成时间
"""
# 获取该章节最新的分析任务
result = await db.execute(
select(AnalysisTask)
.where(AnalysisTask.chapter_id == chapter_id)
.order_by(AnalysisTask.created_at.desc())
.limit(1)
)
task = result.scalar_one_or_none()
if not task:
raise HTTPException(status_code=404, detail="未找到分析任务")
return {
"task_id": task.id,
"chapter_id": task.chapter_id,
"status": task.status,
"progress": task.progress,
"error_message": task.error_message,
"created_at": task.created_at.isoformat() if task.created_at else None,
"started_at": task.started_at.isoformat() if task.started_at else None,
"completed_at": task.completed_at.isoformat() if task.completed_at else None
}
@router.get("/{chapter_id}/analysis", summary="获取章节分析结果")
async def get_chapter_analysis(
chapter_id: str,
db: AsyncSession = Depends(get_db)
):
"""
获取章节的完整分析结果
返回:
- analysis_data: 完整的分析数据(JSON)
- summary: 分析摘要文本
- memories: 提取的记忆列表
- created_at: 分析时间
"""
# 获取分析结果
analysis_result = await db.execute(
select(PlotAnalysis)
.where(PlotAnalysis.chapter_id == chapter_id)
.order_by(PlotAnalysis.created_at.desc())
.limit(1)
)
analysis = analysis_result.scalar_one_or_none()
if not analysis:
raise HTTPException(status_code=404, detail="该章节暂无分析结果")
# 获取相关记忆
memories_result = await db.execute(
select(StoryMemory)
.where(StoryMemory.chapter_id == chapter_id)
.order_by(StoryMemory.importance_score.desc())
)
memories = memories_result.scalars().all()
return {
"chapter_id": chapter_id,
"analysis": analysis.to_dict(), # 使用to_dict()方法
"memories": [
{
"id": mem.id,
"type": mem.memory_type,
"title": mem.title,
"content": mem.content,
"importance": mem.importance_score,
"tags": mem.tags,
"is_foreshadow": mem.is_foreshadow,
"position": mem.chapter_position,
"related_characters": mem.related_characters
}
for mem in memories
],
"created_at": analysis.created_at.isoformat() if analysis.created_at else None
}
@router.get("/{chapter_id}/annotations", summary="获取章节标注数据")
async def get_chapter_annotations(
chapter_id: str,
db: AsyncSession = Depends(get_db)
):
"""
获取章节的标注数据(用于前端展示标注)
返回格式化的标注列表,包含精确位置信息
适用于章节内容的可视化标注展示
"""
# 获取章节
chapter_result = await db.execute(
select(Chapter).where(Chapter.id == chapter_id)
)
chapter = chapter_result.scalar_one_or_none()
if not chapter:
raise HTTPException(status_code=404, detail="章节不存在")
# 获取分析结果
analysis_result = await db.execute(
select(PlotAnalysis)
.where(PlotAnalysis.chapter_id == chapter_id)
.order_by(PlotAnalysis.created_at.desc())
.limit(1)
)
analysis = analysis_result.scalar_one_or_none()
# 获取记忆
memories_result = await db.execute(
select(StoryMemory)
.where(StoryMemory.chapter_id == chapter_id)
.order_by(StoryMemory.importance_score.desc())
)
memories = memories_result.scalars().all()
# 构建标注数据
annotations = []
for mem in memories:
# 优先从数据库读取位置信息
position = mem.chapter_position if mem.chapter_position is not None else -1
length = mem.text_length if hasattr(mem, 'text_length') and mem.text_length is not None else 0
metadata_extra = {}
# 如果数据库中没有位置信息,尝试从分析数据中重新计算
if position == -1 and analysis and chapter.content:
# 根据记忆类型从分析数据中查找对应项
if mem.memory_type == 'hook' and analysis.hooks:
for hook in analysis.hooks:
# 通过标题或内容匹配
if mem.title and hook.get('type') in mem.title:
keyword = hook.get('keyword', '')
if keyword:
pos = chapter.content.find(keyword)
if pos != -1:
position = pos
length = len(keyword)
metadata_extra["strength"] = hook.get('strength', 5)
metadata_extra["position_desc"] = hook.get('position', '')
break
elif mem.memory_type == 'foreshadow' and analysis.foreshadows:
for foreshadow in analysis.foreshadows:
if foreshadow.get('content') in mem.content:
keyword = foreshadow.get('keyword', '')
if keyword:
pos = chapter.content.find(keyword)
if pos != -1:
position = pos
length = len(keyword)
metadata_extra["foreshadow_type"] = foreshadow.get('type', 'planted')
metadata_extra["strength"] = foreshadow.get('strength', 5)
break
elif mem.memory_type == 'plot_point' and analysis.plot_points:
for plot_point in analysis.plot_points:
if plot_point.get('content') in mem.content:
keyword = plot_point.get('keyword', '')
if keyword:
pos = chapter.content.find(keyword)
if pos != -1:
position = pos
length = len(keyword)
break
else:
# 如果数据库有位置,也从分析数据中提取额外的元数据
if analysis:
if mem.memory_type == 'hook' and analysis.hooks:
for hook in analysis.hooks:
if mem.title and hook.get('type') in mem.title:
metadata_extra["strength"] = hook.get('strength', 5)
metadata_extra["position_desc"] = hook.get('position', '')
break
elif mem.memory_type == 'foreshadow' and analysis.foreshadows:
for foreshadow in analysis.foreshadows:
if foreshadow.get('content') in mem.content:
metadata_extra["foreshadow_type"] = foreshadow.get('type', 'planted')
metadata_extra["strength"] = foreshadow.get('strength', 5)
break
annotation = {
"id": mem.id,
"type": mem.memory_type,
"title": mem.title,
"content": mem.content,
"importance": mem.importance_score or 0.5,
"position": position,
"length": length,
"tags": mem.tags or [],
"metadata": {
"is_foreshadow": mem.is_foreshadow,
"related_characters": mem.related_characters or [],
"related_locations": mem.related_locations or [],
**metadata_extra
}
}
annotations.append(annotation)
return {
"chapter_id": chapter_id,
"chapter_number": chapter.chapter_number,
"title": chapter.title,
"word_count": chapter.word_count or 0,
"annotations": annotations,
"has_analysis": analysis is not None,
"summary": {
"total_annotations": len(annotations),
"hooks": len([a for a in annotations if a["type"] == "hook"]),
"foreshadows": len([a for a in annotations if a["type"] == "foreshadow"]),
"plot_points": len([a for a in annotations if a["type"] == "plot_point"]),
"character_events": len([a for a in annotations if a["type"] == "character_event"])
}
}
@router.post("/{chapter_id}/analyze", summary="手动触发章节分析")
async def trigger_chapter_analysis(
chapter_id: str,
request: Request,
background_tasks: BackgroundTasks,
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
"""
手动触发章节分析(用于重新分析或分析旧章节)
"""
# 从请求中获取用户ID
user_id = getattr(request.state, "user_id", None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
# 验证章节存在
chapter_result = await db.execute(
select(Chapter).where(Chapter.id == chapter_id)
)
chapter = chapter_result.scalar_one_or_none()
if not chapter:
raise HTTPException(status_code=404, detail="章节不存在")
if not chapter.content or chapter.content.strip() == "":
raise HTTPException(status_code=400, detail="章节内容为空,无法分析")
# 获取项目信息
project_result = await db.execute(
select(Project).where(Project.id == chapter.project_id)
)
project = project_result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 创建分析任务
analysis_task = AnalysisTask(
chapter_id=chapter_id,
user_id=user_id,
project_id=project.id,
status='pending',
progress=0
)
db.add(analysis_task)
await db.commit()
# 注意:不需要refresh,因为我们只需要id,而id在commit后已经生成
task_id = analysis_task.id
# 启动后台分析任务
background_tasks.add_task(
analyze_chapter_background,
chapter_id=chapter_id,
user_id=user_id,
project_id=project.id,
task_id=task_id,
ai_service=user_ai_service
)
logger.info(f"📋 手动触发分析任务: {task_id}")
return {
"task_id": task_id,
"chapter_id": chapter_id,
"status": "pending",
"message": "分析任务已创建并加入队列"
}
+383
View File
@@ -0,0 +1,383 @@
"""记忆管理API - 提供记忆的查询、分析等接口"""
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, desc
from typing import List, Optional
from app.database import get_db
from app.models.memory import StoryMemory, PlotAnalysis
from app.models.chapter import Chapter
from app.services.memory_service import memory_service
from app.services.plot_analyzer import get_plot_analyzer
from app.services.ai_service import create_user_ai_service
from app.models.settings import Settings
from app.logger import get_logger
import uuid
logger = get_logger(__name__)
router = APIRouter(prefix="/api/memories", tags=["memories"])
@router.post("/projects/{project_id}/analyze-chapter/{chapter_id}")
async def analyze_chapter(
project_id: str,
chapter_id: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
分析章节并生成记忆
对指定章节进行剧情分析,提取钩子、伏笔、情节点等,并存入记忆系统
"""
try:
user_id = request.state.user_id
# 获取章节内容
result = await db.execute(
select(Chapter).where(
and_(
Chapter.id == chapter_id,
Chapter.project_id == project_id
)
)
)
chapter = result.scalar_one_or_none()
if not chapter:
raise HTTPException(status_code=404, detail="章节不存在")
if not chapter.content:
raise HTTPException(status_code=400, detail="章节内容为空,无法分析")
# 获取用户AI设置
settings_result = await db.execute(select(Settings))
settings = settings_result.scalar_one_or_none()
if not settings:
raise HTTPException(status_code=400, detail="请先配置AI设置")
# 创建AI服务
ai_service = create_user_ai_service(
api_provider=settings.api_provider,
api_key=settings.api_key,
api_base_url=settings.api_base_url,
model_name=settings.model_name,
temperature=settings.temperature,
max_tokens=settings.max_tokens
)
# 执行剧情分析
analyzer = get_plot_analyzer(ai_service)
analysis_result = await analyzer.analyze_chapter(
chapter_number=chapter.chapter_number,
title=chapter.title,
content=chapter.content,
word_count=chapter.word_count or len(chapter.content)
)
if not analysis_result:
raise HTTPException(status_code=500, detail="剧情分析失败")
# 保存分析结果到数据库
plot_analysis = PlotAnalysis(
id=str(uuid.uuid4()),
project_id=project_id,
chapter_id=chapter_id,
plot_stage=analysis_result.get('plot_stage'),
conflict_level=analysis_result.get('conflict', {}).get('level'),
conflict_types=analysis_result.get('conflict', {}).get('types'),
emotional_tone=analysis_result.get('emotional_arc', {}).get('primary_emotion'),
emotional_intensity=analysis_result.get('emotional_arc', {}).get('intensity', 0) / 10,
emotional_curve=analysis_result.get('emotional_arc'),
hooks=analysis_result.get('hooks'),
hooks_count=len(analysis_result.get('hooks', [])),
hooks_avg_strength=sum(h.get('strength', 0) for h in analysis_result.get('hooks', [])) / max(len(analysis_result.get('hooks', [])), 1),
foreshadows=analysis_result.get('foreshadows'),
foreshadows_planted=sum(1 for f in analysis_result.get('foreshadows', []) if f.get('type') == 'planted'),
foreshadows_resolved=sum(1 for f in analysis_result.get('foreshadows', []) if f.get('type') == 'resolved'),
plot_points=analysis_result.get('plot_points'),
plot_points_count=len(analysis_result.get('plot_points', [])),
character_states=analysis_result.get('character_states'),
scenes=analysis_result.get('scenes'),
pacing=analysis_result.get('pacing'),
dialogue_ratio=analysis_result.get('dialogue_ratio'),
description_ratio=analysis_result.get('description_ratio'),
overall_quality_score=analysis_result.get('scores', {}).get('overall'),
pacing_score=analysis_result.get('scores', {}).get('pacing'),
engagement_score=analysis_result.get('scores', {}).get('engagement'),
coherence_score=analysis_result.get('scores', {}).get('coherence'),
analysis_report=analyzer.generate_analysis_summary(analysis_result),
suggestions=analysis_result.get('suggestions'),
word_count=chapter.word_count
)
# 检查是否已存在分析记录
existing = await db.execute(
select(PlotAnalysis).where(PlotAnalysis.chapter_id == chapter_id)
)
if existing.scalar_one_or_none():
# 删除旧记录
await db.execute(
select(PlotAnalysis).where(PlotAnalysis.chapter_id == chapter_id)
)
await db.delete(existing.scalar_one())
db.add(plot_analysis)
await db.commit()
# 从分析结果中提取记忆片段
memories_data = analyzer.extract_memories_from_analysis(
analysis_result,
chapter_id,
chapter.chapter_number
)
# 保存记忆到数据库和向量库
saved_count = 0
for mem_data in memories_data:
memory_id = str(uuid.uuid4())
# 保存到关系数据库
memory = StoryMemory(
id=memory_id,
project_id=project_id,
chapter_id=chapter_id,
memory_type=mem_data['type'],
title=mem_data.get('title', ''),
content=mem_data['content'],
story_timeline=chapter.chapter_number,
vector_id=memory_id,
**mem_data['metadata']
)
db.add(memory)
# 保存到向量库
await memory_service.add_memory(
user_id=user_id,
project_id=project_id,
memory_id=memory_id,
content=mem_data['content'],
memory_type=mem_data['type'],
metadata=mem_data['metadata']
)
saved_count += 1
await db.commit()
logger.info(f"✅ 章节分析完成: 保存{saved_count}条记忆")
return {
"success": True,
"message": f"分析完成,提取了{saved_count}条记忆",
"analysis": plot_analysis.to_dict(),
"memories_count": saved_count
}
except HTTPException:
raise
except Exception as e:
logger.error(f"❌ 章节分析失败: {str(e)}")
await db.rollback()
raise HTTPException(status_code=500, detail=f"分析失败: {str(e)}")
@router.get("/projects/{project_id}/memories")
async def get_project_memories(
project_id: str,
request: Request,
memory_type: Optional[str] = None,
chapter_id: Optional[str] = None,
limit: int = 50,
db: AsyncSession = Depends(get_db)
):
"""获取项目的记忆列表"""
try:
user_id = request.state.user_id
# 构建查询
query = select(StoryMemory).where(StoryMemory.project_id == project_id)
if memory_type:
query = query.where(StoryMemory.memory_type == memory_type)
if chapter_id:
query = query.where(StoryMemory.chapter_id == chapter_id)
query = query.order_by(desc(StoryMemory.importance_score), desc(StoryMemory.created_at)).limit(limit)
result = await db.execute(query)
memories = result.scalars().all()
return {
"success": True,
"memories": [mem.to_dict() for mem in memories],
"total": len(memories)
}
except Exception as e:
logger.error(f"❌ 获取记忆失败: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/projects/{project_id}/analysis/{chapter_id}")
async def get_chapter_analysis(
project_id: str,
chapter_id: str,
db: AsyncSession = Depends(get_db)
):
"""获取章节的剧情分析"""
try:
result = await db.execute(
select(PlotAnalysis).where(
and_(
PlotAnalysis.project_id == project_id,
PlotAnalysis.chapter_id == chapter_id
)
)
)
analysis = result.scalar_one_or_none()
if not analysis:
raise HTTPException(status_code=404, detail="该章节还未进行分析")
return {
"success": True,
"analysis": analysis.to_dict()
}
except HTTPException:
raise
except Exception as e:
logger.error(f"❌ 获取分析失败: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/projects/{project_id}/search")
async def search_memories(
project_id: str,
request: Request,
query: str,
memory_types: Optional[List[str]] = None,
limit: int = 10,
min_importance: float = 0.0
):
"""语义搜索项目记忆"""
try:
user_id = request.state.user_id
memories = await memory_service.search_memories(
user_id=user_id,
project_id=project_id,
query=query,
memory_types=memory_types,
limit=limit,
min_importance=min_importance
)
return {
"success": True,
"query": query,
"memories": memories,
"total": len(memories)
}
except Exception as e:
logger.error(f"❌ 搜索记忆失败: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/projects/{project_id}/foreshadows")
async def get_unresolved_foreshadows(
project_id: str,
request: Request,
current_chapter: int,
db: AsyncSession = Depends(get_db)
):
"""获取未完结的伏笔"""
try:
user_id = request.state.user_id
# 从向量库搜索
foreshadows = await memory_service.find_unresolved_foreshadows(
user_id=user_id,
project_id=project_id,
current_chapter=current_chapter
)
return {
"success": True,
"foreshadows": foreshadows,
"total": len(foreshadows)
}
except Exception as e:
logger.error(f"❌ 获取伏笔失败: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/projects/{project_id}/stats")
async def get_memory_stats(
project_id: str,
request: Request
):
"""获取记忆统计信息"""
try:
user_id = request.state.user_id
stats = await memory_service.get_memory_stats(
user_id=user_id,
project_id=project_id
)
return {
"success": True,
"stats": stats
}
except Exception as e:
logger.error(f"❌ 获取统计失败: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/projects/{project_id}/chapters/{chapter_id}/memories")
async def delete_chapter_memories(
project_id: str,
chapter_id: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""删除章节的所有记忆"""
try:
user_id = request.state.user_id
# 从数据库删除
result = await db.execute(
select(StoryMemory).where(
and_(
StoryMemory.project_id == project_id,
StoryMemory.chapter_id == chapter_id
)
)
)
memories = result.scalars().all()
for memory in memories:
await db.delete(memory)
# 从向量库删除
await memory_service.delete_chapter_memories(
user_id=user_id,
project_id=project_id,
chapter_id=chapter_id
)
await db.commit()
return {
"success": True,
"message": f"已删除{len(memories)}条记忆"
}
except Exception as e:
logger.error(f"❌ 删除记忆失败: {str(e)}")
await db.rollback()
raise HTTPException(status_code=500, detail=str(e))
+62 -12
View File
@@ -1,5 +1,5 @@
"""大纲管理API"""
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, delete
from typing import List, AsyncGenerator, Dict, Any
@@ -21,6 +21,7 @@ from app.schemas.outline import (
)
from app.services.ai_service import AIService
from app.services.prompt_service import prompt_service
from app.services.memory_service import memory_service
from app.logger import get_logger
from app.api.settings import get_user_ai_service
from app.utils.sse_response import SSEResponse, create_sse_response
@@ -328,6 +329,7 @@ async def reorder_outlines(
@router.post("/generate", response_model=OutlineListResponse, summary="AI生成/续写大纲")
async def generate_outline(
request: OutlineGenerateRequest,
http_request: Request,
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
@@ -377,8 +379,10 @@ async def generate_outline(
detail="续写模式需要已有大纲,当前项目没有大纲"
)
# 获取用户ID用于记忆检索
user_id = getattr(http_request.state, "user_id", "system")
return await _continue_outline(
request, project, existing_outlines, db, user_ai_service
request, project, existing_outlines, db, user_ai_service, user_id
)
else:
@@ -478,9 +482,10 @@ async def _continue_outline(
project: Project,
existing_outlines: List[Outline],
db: AsyncSession,
user_ai_service: AIService
user_ai_service: AIService,
user_id: str = "system"
) -> OutlineListResponse:
"""续写大纲 - 分批生成,每批5章"""
"""续写大纲 - 分批生成,每批5章(记忆增强版)"""
logger.info(f"续写大纲 - 项目: {project.id}, 已有: {len(existing_outlines)}")
# 分析已有大纲
@@ -545,7 +550,25 @@ async def _continue_outline(
for o in latest_outlines
])
# 使用标准续写提示词模板
# 🧠 构建记忆增强上下文(仅续写模式需要)
memory_context = None
try:
logger.info(f"🧠 为第{batch_num + 1}批构建记忆上下文...")
# 使用最近一章的大纲作为查询
query_outline = recent_outlines[-1].content if recent_outlines else ""
memory_context = await memory_service.build_context_for_generation(
user_id=user_id,
project_id=project.id,
current_chapter=current_start_chapter,
chapter_outline=query_outline,
character_names=[c.name for c in characters] if characters else None
)
logger.info(f"✅ 记忆上下文构建完成: {memory_context['stats']}")
except Exception as e:
logger.warning(f"⚠️ 记忆上下文构建失败,继续不使用记忆: {str(e)}")
memory_context = None
# 使用标准续写提示词模板(支持记忆增强)
prompt = prompt_service.get_outline_continue_prompt(
title=project.title,
theme=request.theme or project.theme or "未设定",
@@ -563,7 +586,8 @@ async def _continue_outline(
plot_stage_instruction=stage_instruction,
start_chapter=current_start_chapter,
story_direction=request.story_direction or "自然延续",
requirements=request.requirements or ""
requirements=request.requirements or "",
memory_context=memory_context
)
# 调用AI生成当前批次
@@ -834,9 +858,10 @@ async def new_outline_generator(
async def continue_outline_generator(
data: Dict[str, Any],
db: AsyncSession,
user_ai_service: AIService
user_ai_service: AIService,
user_id: str = "system"
) -> AsyncGenerator[str, None]:
"""大纲续写SSE生成器 - 分批生成,推送进度"""
"""大纲续写SSE生成器 - 分批生成,推送进度(记忆增强版)"""
db_committed = False
try:
yield await SSEResponse.send_progress("开始续写大纲...", 5)
@@ -940,12 +965,32 @@ async def continue_outline_generator(
for o in latest_outlines
])
# 🧠 构建记忆增强上下文
memory_context = None
try:
yield await SSEResponse.send_progress(
f"🧠 构建记忆上下文...",
batch_progress + 3
)
query_outline = recent_outlines[-1].content if recent_outlines else ""
memory_context = await memory_service.build_context_for_generation(
user_id=user_id,
project_id=project_id,
current_chapter=current_start_chapter,
chapter_outline=query_outline,
character_names=[c.name for c in characters] if characters else None
)
logger.info(f"✅ 记忆上下文: {memory_context['stats']}")
except Exception as e:
logger.warning(f"⚠️ 记忆上下文构建失败: {str(e)}")
memory_context = None
yield await SSEResponse.send_progress(
f"🤖 调用AI生成第{str(batch_num + 1)}批...",
f" 调用AI生成第{str(batch_num + 1)}批...",
batch_progress + 5
)
# 使用标准续写提示词模板
# 使用标准续写提示词模板(支持记忆增强)
prompt = prompt_service.get_outline_continue_prompt(
title=project.title,
theme=data.get("theme") or project.theme or "未设定",
@@ -963,7 +1008,8 @@ async def continue_outline_generator(
plot_stage_instruction=stage_instruction,
start_chapter=current_start_chapter,
story_direction=data.get("story_direction", "自然延续"),
requirements=data.get("requirements", "")
requirements=data.get("requirements", ""),
memory_context=memory_context
)
# 调用AI生成当前批次
@@ -1062,6 +1108,7 @@ async def continue_outline_generator(
@router.post("/generate-stream", summary="AI生成/续写大纲(SSE流式)")
async def generate_outline_stream(
data: Dict[str, Any],
request: Request,
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
@@ -1111,6 +1158,9 @@ async def generate_outline_stream(
mode = "continue" if existing_outlines else "new"
logger.info(f"自动判断模式:{'续写' if existing_outlines else '新建'}")
# 获取用户ID
user_id = getattr(request.state, "user_id", "system")
# 根据模式选择生成器
if mode == "new":
return create_sse_response(new_outline_generator(data, db, user_ai_service))
@@ -1120,7 +1170,7 @@ async def generate_outline_stream(
status_code=400,
detail="续写模式需要已有大纲,当前项目没有大纲"
)
return create_sse_response(continue_outline_generator(data, db, user_ai_service))
return create_sse_response(continue_outline_generator(data, db, user_ai_service, user_id))
else:
raise HTTPException(
status_code=400,
+174 -2
View File
@@ -1,9 +1,11 @@
"""项目管理API"""
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
from fastapi.responses import Response
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, delete
from typing import List
import json
from urllib.parse import quote
from app.database import get_db
from app.models.project import Project
from app.models.character import Character
@@ -17,6 +19,12 @@ from app.schemas.project import (
ProjectResponse,
ProjectListResponse
)
from app.schemas.import_export import (
ExportOptions,
ImportValidationResult,
ImportResult
)
from app.services.import_export_service import ImportExportService
from app.logger import get_logger
from app.utils.data_consistency import (
run_full_data_consistency_check,
@@ -412,4 +420,168 @@ async def fix_project_member_counts(
raise
except Exception as e:
logger.error(f"修复成员计数失败: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"修复失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"修复失败: {str(e)}")
@router.post("/{project_id}/export-data", summary="导出项目数据为JSON")
async def export_project_data(
project_id: str,
options: ExportOptions,
db: AsyncSession = Depends(get_db)
):
"""
导出项目完整数据为JSON格式
Args:
project_id: 项目ID
options: 导出选项
Returns:
JSON文件下载
"""
try:
logger.info(f"开始导出项目数据: {project_id}")
# 检查项目是否存在
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目不存在: {project_id}")
raise HTTPException(status_code=404, detail="项目不存在")
# 导出数据
export_data = await ImportExportService.export_project(
project_id=project_id,
db=db,
include_generation_history=options.include_generation_history,
include_writing_styles=options.include_writing_styles
)
# 转换为JSON
json_content = export_data.model_dump_json(indent=2, exclude_none=True, by_alias=True)
# 生成文件名
safe_title = "".join(c for c in project.title if c.isalnum() or c in (' ', '-', '_'))
from datetime import datetime
date_str = datetime.now().strftime("%Y%m%d")
filename = f"project_{safe_title}_{date_str}.json"
encoded_filename = quote(filename)
logger.info(f"项目数据导出成功: {filename}")
return Response(
content=json_content.encode('utf-8'),
media_type="application/json; charset=utf-8",
headers={
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
"Content-Type": "application/json; charset=utf-8"
}
)
except HTTPException:
raise
except Exception as e:
logger.error(f"导出项目数据失败: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"导出失败: {str(e)}")
@router.post("/validate-import", response_model=ImportValidationResult, summary="验证导入文件")
async def validate_import_file(
file: UploadFile = File(...)
):
"""
验证导入文件的格式和内容
Args:
file: 上传的JSON文件
Returns:
验证结果
"""
try:
logger.info(f"验证导入文件: {file.filename}")
# 检查文件类型
if not file.filename.endswith('.json'):
raise HTTPException(status_code=400, detail="只支持JSON格式文件")
# 读取文件内容
content = await file.read()
# 检查文件大小(50MB限制)
max_size = 50 * 1024 * 1024 # 50MB
if len(content) > max_size:
raise HTTPException(status_code=413, detail="文件大小超过50MB限制")
# 解析JSON
try:
data = json.loads(content.decode('utf-8'))
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"无效的JSON格式: {str(e)}")
# 验证数据
validation_result = ImportExportService.validate_import_data(data)
logger.info(f"文件验证完成: valid={validation_result.valid}")
return validation_result
except HTTPException:
raise
except Exception as e:
logger.error(f"验证导入文件失败: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"验证失败: {str(e)}")
@router.post("/import", response_model=ImportResult, summary="导入项目")
async def import_project(
file: UploadFile = File(...),
db: AsyncSession = Depends(get_db)
):
"""
导入项目数据(创建新项目)
Args:
file: 上传的JSON文件
Returns:
导入结果
"""
try:
logger.info(f"开始导入项目: {file.filename}")
# 检查文件类型
if not file.filename.endswith('.json'):
raise HTTPException(status_code=400, detail="只支持JSON格式文件")
# 读取文件内容
content = await file.read()
# 检查文件大小
max_size = 50 * 1024 * 1024 # 50MB
if len(content) > max_size:
raise HTTPException(status_code=413, detail="文件大小超过50MB限制")
# 解析JSON
try:
data = json.loads(content.decode('utf-8'))
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"无效的JSON格式: {str(e)}")
# 导入数据
import_result = await ImportExportService.import_project(data, db)
if import_result.success:
logger.info(f"项目导入成功: {import_result.project_id}")
else:
logger.warning(f"项目导入失败: {import_result.message}")
return import_result
except HTTPException:
raise
except Exception as e:
logger.error(f"导入项目失败: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"导入失败: {str(e)}")
+9
View File
@@ -15,6 +15,15 @@ logger = get_logger(__name__)
# 创建基类
Base = declarative_base()
# 导入所有模型,确保 Base.metadata 能够发现它们
# 这必须在 Base 创建之后、init_db 之前导入
from app.models import (
Project, Outline, Character, Chapter, GenerationHistory,
Settings, WritingStyle, ProjectDefaultStyle,
RelationshipType, CharacterRelationship, Organization, OrganizationMember,
StoryMemory, PlotAnalysis, AnalysisTask
)
# 引擎缓存:每个用户一个引擎
_engine_cache: Dict[str, Any] = {}
+2 -1
View File
@@ -114,7 +114,7 @@ async def db_session_stats():
from app.api import (
projects, outlines, characters, chapters,
wizard_stream, relationships, organizations,
auth, users, settings, writing_styles
auth, users, settings, writing_styles, memories
)
app.include_router(auth.router, prefix="/api")
@@ -129,6 +129,7 @@ app.include_router(chapters.router, prefix="/api")
app.include_router(relationships.router, prefix="/api")
app.include_router(organizations.router, prefix="/api")
app.include_router(writing_styles.router, prefix="/api")
app.include_router(memories.router) # 记忆管理API (已包含/api前缀)
static_dir = Path(__file__).parent.parent / "static"
if static_dir.exists():
+5
View File
@@ -13,6 +13,8 @@ from app.models.relationship import (
Organization,
OrganizationMember
)
from app.models.memory import StoryMemory, PlotAnalysis
from app.models.analysis_task import AnalysisTask
__all__ = [
"Project",
@@ -27,4 +29,7 @@ __all__ = [
"CharacterRelationship",
"Organization",
"OrganizationMember",
"StoryMemory",
"PlotAnalysis",
"AnalysisTask",
]
+38
View File
@@ -0,0 +1,38 @@
"""分析任务模型 - 追踪异步章节分析任务状态"""
from sqlalchemy import Column, String, Integer, Text, DateTime, ForeignKey, Index
from sqlalchemy.sql import func
from app.database import Base
import uuid
class AnalysisTask(Base):
"""
分析任务表 - 追踪异步分析任务的执行状态
状态流转: pending -> running -> completed/failed
"""
__tablename__ = "analysis_tasks"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="任务ID")
chapter_id = Column(String(36), ForeignKey('chapters.id', ondelete='CASCADE'), nullable=False, comment="章节ID")
user_id = Column(String(50), nullable=False, comment="用户ID")
project_id = Column(String(36), nullable=False, comment="项目ID")
# 任务状态
status = Column(String(20), nullable=False, default='pending', comment="任务状态: pending/running/completed/failed")
progress = Column(Integer, default=0, comment="进度 0-100")
error_message = Column(Text, nullable=True, comment="错误信息")
# 时间戳
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
started_at = Column(DateTime, nullable=True, comment="开始执行时间")
completed_at = Column(DateTime, nullable=True, comment="完成时间")
# 索引优化查询
__table_args__ = (
Index('idx_chapter_id_created', 'chapter_id', 'created_at'),
Index('idx_status', 'status'),
)
def __repr__(self):
return f"<AnalysisTask(id={self.id[:8]}..., chapter_id={self.chapter_id[:8]}..., status={self.status})>"
+200
View File
@@ -0,0 +1,200 @@
"""长期记忆数据模型 - 支持向量检索和剧情分析"""
from sqlalchemy import Column, String, Text, Integer, DateTime, ForeignKey, Float, JSON, Boolean
from sqlalchemy.sql import func
from app.database import Base
import uuid
class StoryMemory(Base):
"""故事记忆表 - 存储结构化的故事片段和元数据"""
__tablename__ = "story_memories"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False, index=True)
chapter_id = Column(String(36), ForeignKey("chapters.id", ondelete="CASCADE"), nullable=True, index=True)
# 记忆类型
memory_type = Column(String(50), nullable=False, index=True, comment="""
记忆类型:
- plot_point: 情节点
- character_event: 角色事件
- world_detail: 世界观细节
- hook: 钩子(悬念/冲突)
- foreshadow: 伏笔
- dialogue: 重要对话
- scene: 场景描写
""")
# 记忆内容
title = Column(String(200), comment="记忆标题/简述")
content = Column(Text, nullable=False, comment="记忆内容摘要(100-500字)")
full_context = Column(Text, comment="完整上下文(可选,用于详细记录)")
# 关联信息
related_characters = Column(JSON, comment="涉及角色ID列表: ['char_id_1', 'char_id_2']")
related_locations = Column(JSON, comment="涉及地点列表: ['地点1', '地点2']")
tags = Column(JSON, comment="标签列表: ['悬念', '转折', '伏笔', '高潮']")
# 重要性评分 (用于过滤和排序)
importance_score = Column(Float, default=0.5, comment="重要性评分 0.0-1.0")
# 时间线定位
story_timeline = Column(Integer, nullable=False, index=True, comment="故事时间线位置(章节序号)")
chapter_position = Column(Integer, default=0, comment="章节内位置(字符位置)")
text_length = Column(Integer, default=0, comment="文本长度(字符数)")
# 伏笔相关字段
is_foreshadow = Column(Integer, default=0, comment="伏笔状态: 0=普通记忆, 1=已埋下伏笔, 2=伏笔已回收")
foreshadow_resolved_at = Column(String(36), ForeignKey("chapters.id", ondelete="SET NULL"), comment="伏笔回收的章节ID")
foreshadow_strength = Column(Float, comment="伏笔强度 0.0-1.0")
# 向量数据库关联
vector_id = Column(String(100), unique=True, comment="向量数据库中的唯一ID")
embedding_model = Column(String(100), default="paraphrase-multilingual-MiniLM-L12-v2", comment="使用的embedding模型")
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
def __repr__(self):
return f"<StoryMemory(id={self.id[:8]}, type={self.memory_type}, title={self.title})>"
def to_dict(self):
"""转换为字典格式"""
return {
"id": self.id,
"project_id": self.project_id,
"chapter_id": self.chapter_id,
"memory_type": self.memory_type,
"title": self.title,
"content": self.content,
"related_characters": self.related_characters,
"related_locations": self.related_locations,
"tags": self.tags,
"importance_score": self.importance_score,
"story_timeline": self.story_timeline,
"is_foreshadow": self.is_foreshadow,
"created_at": self.created_at.isoformat() if self.created_at else None
}
class PlotAnalysis(Base):
"""剧情分析表 - 存储AI分析的章节结构和剧情元素"""
__tablename__ = "plot_analysis"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False, index=True)
chapter_id = Column(String(36), ForeignKey("chapters.id", ondelete="CASCADE"), nullable=False, unique=True, index=True)
# 剧情结构分析
plot_stage = Column(String(50), comment="剧情阶段: 开端/发展/高潮/结局/过渡")
conflict_level = Column(Integer, comment="冲突强度 1-10")
conflict_types = Column(JSON, comment="冲突类型列表: ['人与人', '人与己', '人与环境']")
# 情感分析
emotional_tone = Column(String(100), comment="主导情感: 紧张/温馨/悲伤/激昂/平静")
emotional_intensity = Column(Float, comment="情感强度 0.0-1.0")
emotional_curve = Column(JSON, comment="情感曲线: {start: 0.3, middle: 0.7, end: 0.5}")
# 钩子分析 (Hook Analysis)
hooks = Column(JSON, comment="""钩子列表 - 吸引读者的元素: [
{
"type": "悬念|情感|冲突|认知",
"content": "具体内容",
"strength": 8,
"position": "开头|中段|结尾"
}
]""")
hooks_count = Column(Integer, default=0, comment="钩子数量")
hooks_avg_strength = Column(Float, comment="钩子平均强度")
# 伏笔分析 (Foreshadowing Analysis)
foreshadows = Column(JSON, comment="""伏笔列表: [
{
"content": "伏笔内容",
"type": "planted|resolved",
"strength": 7,
"subtlety": 8,
"reference_chapter": 3
}
]""")
foreshadows_planted = Column(Integer, default=0, comment="本章埋下的伏笔数量")
foreshadows_resolved = Column(Integer, default=0, comment="本章回收的伏笔数量")
# 关键情节点 (Plot Points)
plot_points = Column(JSON, comment="""情节点列表: [
{
"content": "情节点描述",
"importance": 0.9,
"type": "revelation|conflict|resolution|transition",
"impact": "对故事的影响描述"
}
]""")
plot_points_count = Column(Integer, default=0, comment="情节点数量")
# 角色状态追踪 (Character State Tracking)
character_states = Column(JSON, comment="""角色状态变化: [
{
"character_id": "xxx",
"character_name": "张三",
"state_before": "犹豫不决",
"state_after": "坚定信念",
"psychological_change": "内心描述",
"key_event": "触发事件",
"relationship_changes": {"李四": "关系变化"}
}
]""")
# 场景和氛围
scenes = Column(JSON, comment="场景列表: [{location: '地点', atmosphere: '氛围', duration: '时长'}]")
pacing = Column(String(50), comment="节奏: slow|moderate|fast|varied")
# 质量评分
overall_quality_score = Column(Float, comment="整体质量评分 0.0-10.0")
pacing_score = Column(Float, comment="节奏评分 0.0-10.0")
engagement_score = Column(Float, comment="吸引力评分 0.0-10.0")
coherence_score = Column(Float, comment="连贯性评分 0.0-10.0")
# 文本分析报告
analysis_report = Column(Text, comment="完整的文字分析报告")
suggestions = Column(JSON, comment="改进建议列表: ['建议1', '建议2']")
# 统计信息
word_count = Column(Integer, comment="章节字数")
dialogue_ratio = Column(Float, comment="对话占比 0.0-1.0")
description_ratio = Column(Float, comment="描写占比 0.0-1.0")
created_at = Column(DateTime, server_default=func.now(), comment="分析时间")
def __repr__(self):
return f"<PlotAnalysis(chapter_id={self.chapter_id[:8]}, stage={self.plot_stage}, quality={self.overall_quality_score})>"
def to_dict(self):
"""转换为字典格式"""
return {
"id": self.id,
"chapter_id": self.chapter_id,
"plot_stage": self.plot_stage,
"conflict_level": self.conflict_level,
"conflict_types": self.conflict_types or [],
"emotional_tone": self.emotional_tone,
"emotional_intensity": self.emotional_intensity or 0.0,
"hooks": self.hooks or [],
"hooks_count": self.hooks_count or 0,
"foreshadows": self.foreshadows or [],
"foreshadows_planted": self.foreshadows_planted or 0,
"foreshadows_resolved": self.foreshadows_resolved or 0,
"plot_points": self.plot_points or [],
"plot_points_count": self.plot_points_count or 0,
"character_states": self.character_states or [],
"scenes": self.scenes or [],
"pacing": self.pacing,
"overall_quality_score": self.overall_quality_score or 0.0,
"pacing_score": self.pacing_score or 0.0,
"engagement_score": self.engagement_score or 0.0,
"coherence_score": self.coherence_score or 0.0,
"analysis_report": self.analysis_report,
"suggestions": self.suggestions or [],
"dialogue_ratio": self.dialogue_ratio or 0.0,
"description_ratio": self.description_ratio or 0.0,
"created_at": self.created_at.isoformat() if self.created_at else None
}
+136
View File
@@ -0,0 +1,136 @@
"""导入导出相关的Pydantic模型"""
from pydantic import BaseModel, Field
from typing import Optional, List, Dict, Any
from datetime import datetime
class ExportOptions(BaseModel):
"""导出选项"""
include_generation_history: bool = Field(False, description="是否包含生成历史")
include_writing_styles: bool = Field(True, description="是否包含写作风格")
class ChapterExportData(BaseModel):
"""章节导出数据"""
title: str
content: Optional[str] = None
summary: Optional[str] = None
chapter_number: int
word_count: int = 0
status: str = "draft"
created_at: Optional[str] = None
class CharacterExportData(BaseModel):
"""角色导出数据"""
name: str
age: Optional[str] = None
gender: Optional[str] = None
is_organization: bool = False
role_type: Optional[str] = None
personality: Optional[str] = None
background: Optional[str] = None
appearance: Optional[str] = None
traits: Optional[List[str]] = None
organization_type: Optional[str] = None
organization_purpose: Optional[str] = None
created_at: Optional[str] = None
class OutlineExportData(BaseModel):
"""大纲导出数据"""
title: str
content: Optional[str] = None
structure: Optional[str] = None
order_index: Optional[int] = None
created_at: Optional[str] = None
class RelationshipExportData(BaseModel):
"""关系导出数据"""
source_name: str
target_name: str
relationship_name: Optional[str] = None
intimacy_level: int = 50
status: str = "active"
description: Optional[str] = None
started_at: Optional[str] = None
class OrganizationExportData(BaseModel):
"""组织详情导出数据"""
character_name: str
parent_org_name: Optional[str] = None
power_level: int = 50
member_count: int = 0
location: Optional[str] = None
motto: Optional[str] = None
color: Optional[str] = None
class OrganizationMemberExportData(BaseModel):
"""组织成员导出数据"""
organization_name: str
character_name: str
position: str
rank: int = 0
status: str = "active"
joined_at: Optional[str] = None
loyalty: int = 50
contribution: int = 0
notes: Optional[str] = None
class WritingStyleExportData(BaseModel):
"""写作风格导出数据"""
name: str
style_type: str
preset_id: Optional[str] = None
description: Optional[str] = None
prompt_content: str
order_index: int = 0
class GenerationHistoryExportData(BaseModel):
"""生成历史导出数据"""
chapter_title: Optional[str] = None
prompt: Optional[str] = None
generated_content: Optional[str] = None
model: Optional[str] = None
tokens_used: Optional[int] = None
generation_time: Optional[float] = None
created_at: Optional[str] = None
class ProjectExportData(BaseModel):
"""项目完整导出数据"""
version: str = "1.0.0"
export_time: str
project: Dict[str, Any]
chapters: List[ChapterExportData] = []
characters: List[CharacterExportData] = []
outlines: List[OutlineExportData] = []
relationships: List[RelationshipExportData] = []
organizations: List[OrganizationExportData] = []
organization_members: List[OrganizationMemberExportData] = []
writing_styles: List[WritingStyleExportData] = []
generation_history: List[GenerationHistoryExportData] = []
class ImportValidationResult(BaseModel):
"""导入验证结果"""
valid: bool
version: str
project_name: Optional[str] = None
statistics: Dict[str, int] = {}
errors: List[str] = []
warnings: List[str] = []
class ImportResult(BaseModel):
"""导入结果"""
success: bool
project_id: Optional[str] = None
message: str
statistics: Dict[str, int] = {}
warnings: List[str] = []
@@ -0,0 +1,769 @@
"""导入导出服务"""
import json
from datetime import datetime
from typing import Dict, List, Optional, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.models.project import Project
from app.models.chapter import Chapter
from app.models.character import Character
from app.models.outline import Outline
from app.models.relationship import CharacterRelationship, Organization, OrganizationMember
from app.models.writing_style import WritingStyle
from app.models.generation_history import GenerationHistory
from app.schemas.import_export import (
ProjectExportData,
ChapterExportData,
CharacterExportData,
OutlineExportData,
RelationshipExportData,
OrganizationExportData,
OrganizationMemberExportData,
WritingStyleExportData,
GenerationHistoryExportData,
ImportValidationResult,
ImportResult
)
from app.logger import get_logger
logger = get_logger(__name__)
class ImportExportService:
"""导入导出服务类"""
SUPPORTED_VERSION = "1.0.0"
@staticmethod
async def export_project(
project_id: str,
db: AsyncSession,
include_generation_history: bool = False,
include_writing_styles: bool = True
) -> ProjectExportData:
"""
导出项目完整数据
Args:
project_id: 项目ID
db: 数据库会话
include_generation_history: 是否包含生成历史
include_writing_styles: 是否包含写作风格
Returns:
ProjectExportData: 导出的项目数据
"""
logger.info(f"开始导出项目: {project_id}")
# 获取项目基本信息
result = await db.execute(select(Project).where(Project.id == project_id))
project = result.scalar_one_or_none()
if not project:
raise ValueError(f"项目不存在: {project_id}")
# 项目基本信息
project_data = {
"title": project.title,
"description": project.description,
"theme": project.theme,
"genre": project.genre,
"target_words": project.target_words,
"current_words": project.current_words,
"status": project.status,
"world_time_period": project.world_time_period,
"world_location": project.world_location,
"world_atmosphere": project.world_atmosphere,
"world_rules": project.world_rules,
"chapter_count": project.chapter_count,
"narrative_perspective": project.narrative_perspective,
"character_count": project.character_count,
"created_at": project.created_at.isoformat() if project.created_at else None,
}
# 导出章节
chapters = await ImportExportService._export_chapters(project_id, db)
logger.info(f"导出章节数: {len(chapters)}")
# 导出角色
characters = await ImportExportService._export_characters(project_id, db)
logger.info(f"导出角色数: {len(characters)}")
# 导出大纲
outlines = await ImportExportService._export_outlines(project_id, db)
logger.info(f"导出大纲数: {len(outlines)}")
# 导出关系
relationships = await ImportExportService._export_relationships(project_id, db)
logger.info(f"导出关系数: {len(relationships)}")
# 导出组织详情
organizations = await ImportExportService._export_organizations(project_id, db)
logger.info(f"导出组织数: {len(organizations)}")
# 导出组织成员
org_members = await ImportExportService._export_organization_members(project_id, db)
logger.info(f"导出组织成员数: {len(org_members)}")
# 导出写作风格(可选)
writing_styles = []
if include_writing_styles:
writing_styles = await ImportExportService._export_writing_styles(project_id, db)
logger.info(f"导出写作风格数: {len(writing_styles)}")
# 导出生成历史(可选)
generation_history = []
if include_generation_history:
generation_history = await ImportExportService._export_generation_history(project_id, db)
logger.info(f"导出生成历史数: {len(generation_history)}")
export_data = ProjectExportData(
version=ImportExportService.SUPPORTED_VERSION,
export_time=datetime.utcnow().isoformat(),
project=project_data,
chapters=chapters,
characters=characters,
outlines=outlines,
relationships=relationships,
organizations=organizations,
organization_members=org_members,
writing_styles=writing_styles,
generation_history=generation_history
)
logger.info(f"项目导出完成: {project_id}")
return export_data
@staticmethod
async def _export_chapters(project_id: str, db: AsyncSession) -> List[ChapterExportData]:
"""导出章节"""
result = await db.execute(
select(Chapter)
.where(Chapter.project_id == project_id)
.order_by(Chapter.chapter_number)
)
chapters = result.scalars().all()
return [
ChapterExportData(
title=ch.title,
content=ch.content,
summary=ch.summary,
chapter_number=ch.chapter_number,
word_count=ch.word_count or 0,
status=ch.status,
created_at=ch.created_at.isoformat() if ch.created_at else None
)
for ch in chapters
]
@staticmethod
async def _export_characters(project_id: str, db: AsyncSession) -> List[CharacterExportData]:
"""导出角色"""
result = await db.execute(
select(Character).where(Character.project_id == project_id)
)
characters = result.scalars().all()
exported = []
for char in characters:
# 解析traits JSON
traits = None
if char.traits:
try:
traits = json.loads(char.traits) if isinstance(char.traits, str) else char.traits
except:
traits = None
exported.append(CharacterExportData(
name=char.name,
age=char.age,
gender=char.gender,
is_organization=char.is_organization or False,
role_type=char.role_type,
personality=char.personality,
background=char.background,
appearance=char.appearance,
traits=traits,
organization_type=char.organization_type,
organization_purpose=char.organization_purpose,
created_at=char.created_at.isoformat() if char.created_at else None
))
return exported
@staticmethod
async def _export_outlines(project_id: str, db: AsyncSession) -> List[OutlineExportData]:
"""导出大纲"""
result = await db.execute(
select(Outline)
.where(Outline.project_id == project_id)
.order_by(Outline.order_index)
)
outlines = result.scalars().all()
return [
OutlineExportData(
title=ol.title,
content=ol.content,
structure=ol.structure,
order_index=ol.order_index,
created_at=ol.created_at.isoformat() if ol.created_at else None
)
for ol in outlines
]
@staticmethod
async def _export_relationships(project_id: str, db: AsyncSession) -> List[RelationshipExportData]:
"""导出关系"""
result = await db.execute(
select(CharacterRelationship, Character)
.join(Character, CharacterRelationship.character_from_id == Character.id)
.where(CharacterRelationship.project_id == project_id)
)
relationships = result.all()
exported = []
for rel, char_from in relationships:
# 获取目标角色名称
target_result = await db.execute(
select(Character).where(Character.id == rel.character_to_id)
)
char_to = target_result.scalar_one_or_none()
if char_to:
exported.append(RelationshipExportData(
source_name=char_from.name,
target_name=char_to.name,
relationship_name=rel.relationship_name,
intimacy_level=rel.intimacy_level or 50,
status=rel.status or "active",
description=rel.description,
started_at=rel.started_at
))
return exported
@staticmethod
async def _export_organizations(project_id: str, db: AsyncSession) -> List[OrganizationExportData]:
"""导出组织详情"""
result = await db.execute(
select(Organization, Character)
.join(Character, Organization.character_id == Character.id)
.where(Organization.project_id == project_id)
)
organizations = result.all()
exported = []
for org, char in organizations:
# 获取父组织名称
parent_name = None
if org.parent_org_id:
parent_result = await db.execute(
select(Organization, Character)
.join(Character, Organization.character_id == Character.id)
.where(Organization.id == org.parent_org_id)
)
parent_data = parent_result.first()
if parent_data:
parent_name = parent_data[1].name
exported.append(OrganizationExportData(
character_name=char.name,
parent_org_name=parent_name,
power_level=org.power_level or 50,
member_count=org.member_count or 0,
location=org.location,
motto=org.motto,
color=org.color
))
return exported
@staticmethod
async def _export_organization_members(project_id: str, db: AsyncSession) -> List[OrganizationMemberExportData]:
"""导出组织成员"""
result = await db.execute(
select(OrganizationMember, Organization, Character)
.join(Organization, OrganizationMember.organization_id == Organization.id)
.join(Character, Organization.character_id == Character.id)
.where(Organization.project_id == project_id)
)
members = result.all()
exported = []
for member, org, org_char in members:
# 获取成员角色名称
char_result = await db.execute(
select(Character).where(Character.id == member.character_id)
)
member_char = char_result.scalar_one_or_none()
if member_char:
exported.append(OrganizationMemberExportData(
organization_name=org_char.name,
character_name=member_char.name,
position=member.position,
rank=member.rank or 0,
status=member.status or "active",
joined_at=member.joined_at,
loyalty=member.loyalty or 50,
contribution=member.contribution or 0,
notes=member.notes
))
return exported
@staticmethod
async def _export_writing_styles(project_id: str, db: AsyncSession) -> List[WritingStyleExportData]:
"""导出写作风格"""
result = await db.execute(
select(WritingStyle)
.where(WritingStyle.project_id == project_id)
.order_by(WritingStyle.order_index)
)
styles = result.scalars().all()
return [
WritingStyleExportData(
name=style.name,
style_type=style.style_type,
preset_id=style.preset_id,
description=style.description,
prompt_content=style.prompt_content,
order_index=style.order_index or 0
)
for style in styles
]
@staticmethod
async def _export_generation_history(project_id: str, db: AsyncSession) -> List[GenerationHistoryExportData]:
"""导出生成历史"""
result = await db.execute(
select(GenerationHistory, Chapter)
.outerjoin(Chapter, GenerationHistory.chapter_id == Chapter.id)
.where(GenerationHistory.project_id == project_id)
.order_by(GenerationHistory.created_at.desc())
.limit(100) # 限制最多导出100条历史记录
)
histories = result.all()
return [
GenerationHistoryExportData(
chapter_title=chapter.title if chapter else None,
prompt=history.prompt,
generated_content=history.generated_content,
model=history.model,
tokens_used=history.tokens_used,
generation_time=history.generation_time,
created_at=history.created_at.isoformat() if history.created_at else None
)
for history, chapter in histories
]
@staticmethod
def validate_import_data(data: Dict) -> ImportValidationResult:
"""
验证导入数据
Args:
data: 导入的JSON数据
Returns:
ImportValidationResult: 验证结果
"""
errors = []
warnings = []
statistics = {}
# 检查版本
version = data.get("version", "")
if not version:
errors.append("缺少版本信息")
elif version != ImportExportService.SUPPORTED_VERSION:
warnings.append(f"版本不匹配: 导入文件版本为 {version}, 当前支持版本为 {ImportExportService.SUPPORTED_VERSION}")
# 检查必需字段
if "project" not in data:
errors.append("缺少项目信息")
else:
project = data["project"]
if not project.get("title"):
errors.append("项目标题不能为空")
# 统计数据
statistics = {
"chapters": len(data.get("chapters", [])),
"characters": len(data.get("characters", [])),
"outlines": len(data.get("outlines", [])),
"relationships": len(data.get("relationships", [])),
"organizations": len(data.get("organizations", [])),
"organization_members": len(data.get("organization_members", [])),
"writing_styles": len(data.get("writing_styles", [])),
"generation_history": len(data.get("generation_history", []))
}
# 检查数据完整性
if statistics["chapters"] == 0:
warnings.append("项目没有章节数据")
if statistics["characters"] == 0:
warnings.append("项目没有角色数据")
project_name = data.get("project", {}).get("title", "未知项目")
return ImportValidationResult(
valid=len(errors) == 0,
version=version,
project_name=project_name,
statistics=statistics,
errors=errors,
warnings=warnings
)
@staticmethod
async def import_project(
data: Dict,
db: AsyncSession
) -> ImportResult:
"""
导入项目数据(创建新项目)
Args:
data: 导入的JSON数据
db: 数据库会话
Returns:
ImportResult: 导入结果
"""
warnings = []
statistics = {}
try:
# 验证数据
validation = ImportExportService.validate_import_data(data)
if not validation.valid:
return ImportResult(
success=False,
message=f"数据验证失败: {', '.join(validation.errors)}",
statistics={},
warnings=validation.warnings
)
warnings.extend(validation.warnings)
logger.info(f"开始导入项目: {validation.project_name}")
# 创建项目
project_data = data["project"]
new_project = Project(
title=project_data.get("title"),
description=project_data.get("description"),
theme=project_data.get("theme"),
genre=project_data.get("genre"),
target_words=project_data.get("target_words"),
status=project_data.get("status", "planning"),
world_time_period=project_data.get("world_time_period"),
world_location=project_data.get("world_location"),
world_atmosphere=project_data.get("world_atmosphere"),
world_rules=project_data.get("world_rules"),
chapter_count=project_data.get("chapter_count"),
narrative_perspective=project_data.get("narrative_perspective"),
character_count=project_data.get("character_count"),
current_words=project_data.get("current_words", 0), # 保留原项目的字数
wizard_step=4, # 导入的项目设置为向导完成状态
wizard_status="completed" # 标记向导已完成
)
db.add(new_project)
await db.flush() # 获取project_id
logger.info(f"创建项目成功: {new_project.id}")
# 导入章节
chapters_count = await ImportExportService._import_chapters(
new_project.id, data.get("chapters", []), db
)
statistics["chapters"] = chapters_count
logger.info(f"导入章节数: {chapters_count}")
# 导入角色(包括组织)
char_mapping = await ImportExportService._import_characters(
new_project.id, data.get("characters", []), db
)
statistics["characters"] = len(char_mapping)
logger.info(f"导入角色数: {len(char_mapping)}")
# 导入大纲
outlines_count = await ImportExportService._import_outlines(
new_project.id, data.get("outlines", []), db
)
statistics["outlines"] = outlines_count
logger.info(f"导入大纲数: {outlines_count}")
# 导入关系
relationships_count = await ImportExportService._import_relationships(
new_project.id, data.get("relationships", []), char_mapping, db
)
statistics["relationships"] = relationships_count
logger.info(f"导入关系数: {relationships_count}")
# 导入组织详情
org_mapping = await ImportExportService._import_organizations(
new_project.id, data.get("organizations", []), char_mapping, db
)
statistics["organizations"] = len(org_mapping)
logger.info(f"导入组织数: {len(org_mapping)}")
# 导入组织成员
org_members_count = await ImportExportService._import_organization_members(
data.get("organization_members", []), char_mapping, org_mapping, db
)
statistics["organization_members"] = org_members_count
logger.info(f"导入组织成员数: {org_members_count}")
# 导入写作风格
styles_count = await ImportExportService._import_writing_styles(
new_project.id, data.get("writing_styles", []), db
)
statistics["writing_styles"] = styles_count
logger.info(f"导入写作风格数: {styles_count}")
# 提交事务
await db.commit()
logger.info(f"项目导入完成: {new_project.id}")
return ImportResult(
success=True,
project_id=new_project.id,
message="项目导入成功",
statistics=statistics,
warnings=warnings
)
except Exception as e:
await db.rollback()
logger.error(f"导入项目失败: {str(e)}", exc_info=True)
return ImportResult(
success=False,
message=f"导入失败: {str(e)}",
statistics=statistics,
warnings=warnings
)
@staticmethod
async def _import_chapters(
project_id: str,
chapters_data: List[Dict],
db: AsyncSession
) -> int:
"""导入章节"""
count = 0
for ch_data in chapters_data:
chapter = Chapter(
project_id=project_id,
title=ch_data.get("title"),
content=ch_data.get("content"),
summary=ch_data.get("summary"),
chapter_number=ch_data.get("chapter_number"),
word_count=ch_data.get("word_count", 0),
status=ch_data.get("status", "draft")
)
db.add(chapter)
count += 1
return count
@staticmethod
async def _import_characters(
project_id: str,
characters_data: List[Dict],
db: AsyncSession
) -> Dict[str, str]:
"""导入角色,返回名称到ID的映射"""
char_mapping = {}
for char_data in characters_data:
# 处理traits
traits = char_data.get("traits")
if traits and isinstance(traits, list):
traits = json.dumps(traits, ensure_ascii=False)
character = Character(
project_id=project_id,
name=char_data.get("name"),
age=char_data.get("age"),
gender=char_data.get("gender"),
is_organization=char_data.get("is_organization", False),
role_type=char_data.get("role_type"),
personality=char_data.get("personality"),
background=char_data.get("background"),
appearance=char_data.get("appearance"),
traits=traits,
organization_type=char_data.get("organization_type"),
organization_purpose=char_data.get("organization_purpose")
)
db.add(character)
await db.flush() # 获取ID
char_mapping[char_data.get("name")] = character.id
return char_mapping
@staticmethod
async def _import_outlines(
project_id: str,
outlines_data: List[Dict],
db: AsyncSession
) -> int:
"""导入大纲"""
count = 0
for ol_data in outlines_data:
outline = Outline(
project_id=project_id,
title=ol_data.get("title"),
content=ol_data.get("content"),
structure=ol_data.get("structure"),
order_index=ol_data.get("order_index")
)
db.add(outline)
count += 1
return count
@staticmethod
async def _import_relationships(
project_id: str,
relationships_data: List[Dict],
char_mapping: Dict[str, str],
db: AsyncSession
) -> int:
"""导入关系"""
count = 0
for rel_data in relationships_data:
source_name = rel_data.get("source_name")
target_name = rel_data.get("target_name")
# 查找角色ID
source_id = char_mapping.get(source_name)
target_id = char_mapping.get(target_name)
if source_id and target_id:
relationship = CharacterRelationship(
project_id=project_id,
character_from_id=source_id,
character_to_id=target_id,
relationship_name=rel_data.get("relationship_name"),
intimacy_level=rel_data.get("intimacy_level", 50),
status=rel_data.get("status", "active"),
description=rel_data.get("description"),
started_at=rel_data.get("started_at")
)
db.add(relationship)
count += 1
return count
@staticmethod
async def _import_organizations(
project_id: str,
organizations_data: List[Dict],
char_mapping: Dict[str, str],
db: AsyncSession
) -> Dict[str, str]:
"""导入组织详情,返回名称到ID的映射"""
org_mapping = {}
# 第一遍:创建所有组织(不设置父组织)
temp_orgs = []
for org_data in organizations_data:
char_name = org_data.get("character_name")
char_id = char_mapping.get(char_name)
if char_id:
organization = Organization(
project_id=project_id,
character_id=char_id,
power_level=org_data.get("power_level", 50),
member_count=org_data.get("member_count", 0),
location=org_data.get("location"),
motto=org_data.get("motto"),
color=org_data.get("color")
)
db.add(organization)
temp_orgs.append((organization, org_data.get("parent_org_name")))
await db.flush() # 获取所有组织的ID
# 建立名称到ID的映射
for org, _ in temp_orgs:
# 通过character_id查找角色名
result = await db.execute(
select(Character).where(Character.id == org.character_id)
)
char = result.scalar_one_or_none()
if char:
org_mapping[char.name] = org.id
# 第二遍:设置父组织关系
for org, parent_name in temp_orgs:
if parent_name:
parent_id = org_mapping.get(parent_name)
if parent_id:
org.parent_org_id = parent_id
return org_mapping
@staticmethod
async def _import_organization_members(
org_members_data: List[Dict],
char_mapping: Dict[str, str],
org_mapping: Dict[str, str],
db: AsyncSession
) -> int:
"""导入组织成员"""
count = 0
for member_data in org_members_data:
org_name = member_data.get("organization_name")
char_name = member_data.get("character_name")
org_id = org_mapping.get(org_name)
char_id = char_mapping.get(char_name)
if org_id and char_id:
member = OrganizationMember(
organization_id=org_id,
character_id=char_id,
position=member_data.get("position"),
rank=member_data.get("rank", 0),
status=member_data.get("status", "active"),
joined_at=member_data.get("joined_at"),
loyalty=member_data.get("loyalty", 50),
contribution=member_data.get("contribution", 0),
notes=member_data.get("notes")
)
db.add(member)
count += 1
return count
@staticmethod
async def _import_writing_styles(
project_id: str,
styles_data: List[Dict],
db: AsyncSession
) -> int:
"""导入写作风格"""
count = 0
for style_data in styles_data:
style = WritingStyle(
project_id=project_id,
name=style_data.get("name"),
style_type=style_data.get("style_type"),
preset_id=style_data.get("preset_id"),
description=style_data.get("description"),
prompt_content=style_data.get("prompt_content"),
order_index=style_data.get("order_index", 0)
)
db.add(style)
count += 1
return count
+739
View File
@@ -0,0 +1,739 @@
"""向量记忆服务 - 基于ChromaDB实现长期记忆和语义检索"""
import chromadb
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Any, Optional
import json
from datetime import datetime
from app.logger import get_logger
import os
import hashlib
logger = get_logger(__name__)
# 配置离线模式,避免联网检查
os.environ['TRANSFORMERS_OFFLINE'] = '1'
os.environ['HF_DATASETS_OFFLINE'] = '1'
class MemoryService:
"""向量记忆管理服务 - 实现语义检索和长期记忆"""
_instance = None
_initialized = False
def __new__(cls):
"""单例模式"""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
"""初始化ChromaDB和Embedding模型"""
if self._initialized:
return
try:
# 确保数据目录存在
chroma_dir = "data/chroma_db"
os.makedirs(chroma_dir, exist_ok=True)
# 初始化ChromaDB客户端(使用新API - PersistentClient)
self.client = chromadb.PersistentClient(path=chroma_dir)
# 初始化多语言embedding模型(支持中文)
logger.info("🔄 正在加载Embedding模型...")
# 确保模型缓存目录存在
model_cache_dir = 'data/models'
os.makedirs(model_cache_dir, exist_ok=True)
try:
# 优先使用本地缓存的模型
# cache_folder会让模型优先从本地加载,只有不存在时才联网下载
self.embedding_model = SentenceTransformer(
'paraphrase-multilingual-MiniLM-L12-v2',
cache_folder=model_cache_dir,
device='cpu' # 明确指定使用CPU
)
logger.info("✅ Embedding模型加载成功")
except Exception as e:
logger.warning(f"⚠️ 无法加载多语言模型: {str(e)}")
logger.info("🔄 尝试使用备用模型...")
try:
# 降级到更小的模型作为备选
self.embedding_model = SentenceTransformer(
'all-MiniLM-L6-v2',
cache_folder=model_cache_dir,
device='cpu'
)
logger.info("✅ 使用备用Embedding模型")
except Exception as e2:
logger.error(f"❌ 所有模型加载失败: {str(e2)}")
logger.error("💡 模型首次使用需要联网下载(约420MB)")
logger.error(" 或手动下载模型文件到 data/models 目录")
raise RuntimeError("无法加载任何Embedding模型")
self._initialized = True
logger.info("✅ MemoryService初始化成功")
logger.info(f" - ChromaDB目录: {chroma_dir}")
logger.info(f" - Embedding模型: paraphrase-multilingual-MiniLM-L12-v2")
except Exception as e:
logger.error(f"❌ MemoryService初始化失败: {str(e)}")
raise
def get_collection(self, user_id: str, project_id: str):
"""
获取或创建项目的记忆集合
每个用户的每个项目有独立的collection,实现数据隔离
Args:
user_id: 用户ID
project_id: 项目ID
Returns:
ChromaDB Collection对象
"""
# ChromaDB collection命名规则:
# 1. 3-63字符(最重要!)
# 2. 开头和结尾必须是字母或数字
# 3. 只能包含字母、数字、下划线或短横线
# 4. 不能包含连续的点(..)
# 5. 不能是有效的IPv4地址
# 使用SHA256哈希压缩ID长度,确保不超过63字符
# 格式: u_{user_hash}_p_{project_hash} (约30字符)
user_hash = hashlib.sha256(user_id.encode()).hexdigest()[:8]
project_hash = hashlib.sha256(project_id.encode()).hexdigest()[:8]
collection_name = f"u_{user_hash}_p_{project_hash}"
try:
return self.client.get_or_create_collection(
name=collection_name,
metadata={
"user_id": user_id,
"project_id": project_id,
"created_at": datetime.now().isoformat()
}
)
except Exception as e:
logger.error(f"❌ 获取collection失败: {str(e)}")
raise
async def add_memory(
self,
user_id: str,
project_id: str,
memory_id: str,
content: str,
memory_type: str,
metadata: Dict[str, Any]
) -> bool:
"""
添加记忆到向量数据库
Args:
user_id: 用户ID
project_id: 项目ID
memory_id: 记忆唯一ID
content: 记忆内容(将被转换为向量)
memory_type: 记忆类型
metadata: 附加元数据
Returns:
是否添加成功
"""
try:
collection = self.get_collection(user_id, project_id)
# 生成文本的向量表示
embedding = self.embedding_model.encode(content).tolist()
# 准备元数据(ChromaDB要求所有值为基础类型)
chroma_metadata = {
"memory_type": memory_type,
"chapter_id": str(metadata.get("chapter_id", "")),
"chapter_number": int(metadata.get("chapter_number", 0)),
"importance": float(metadata.get("importance_score", 0.5)),
"tags": json.dumps(metadata.get("tags", []), ensure_ascii=False),
"title": str(metadata.get("title", ""))[:200], # 限制长度
"is_foreshadow": int(metadata.get("is_foreshadow", 0)),
"created_at": datetime.now().isoformat()
}
# 添加相关角色信息
if metadata.get("related_characters"):
chroma_metadata["related_characters"] = json.dumps(
metadata["related_characters"],
ensure_ascii=False
)
# 存储到向量库
collection.add(
ids=[memory_id],
embeddings=[embedding],
documents=[content],
metadatas=[chroma_metadata]
)
logger.info(f"✅ 记忆已添加: {memory_id[:8]}... (类型:{memory_type}, 重要性:{chroma_metadata['importance']})")
return True
except Exception as e:
logger.error(f"❌ 添加记忆失败: {str(e)}")
return False
async def batch_add_memories(
self,
user_id: str,
project_id: str,
memories: List[Dict[str, Any]]
) -> int:
"""
批量添加记忆(性能更好)
Args:
user_id: 用户ID
project_id: 项目ID
memories: 记忆列表,每个包含id、content、type、metadata
Returns:
成功添加的数量
"""
if not memories:
return 0
try:
collection = self.get_collection(user_id, project_id)
ids = []
documents = []
metadatas = []
embeddings = []
# 批量准备数据
for mem in memories:
ids.append(mem['id'])
documents.append(mem['content'])
# 生成embedding
embedding = self.embedding_model.encode(mem['content']).tolist()
embeddings.append(embedding)
# 准备元数据
metadata = mem.get('metadata', {})
chroma_metadata = {
"memory_type": mem['type'],
"chapter_id": str(metadata.get("chapter_id", "")),
"chapter_number": int(metadata.get("chapter_number", 0)),
"importance": float(metadata.get("importance_score", 0.5)),
"tags": json.dumps(metadata.get("tags", []), ensure_ascii=False),
"title": str(metadata.get("title", ""))[:200],
"is_foreshadow": int(metadata.get("is_foreshadow", 0)),
"created_at": datetime.now().isoformat()
}
metadatas.append(chroma_metadata)
# 批量添加
collection.add(
ids=ids,
embeddings=embeddings,
documents=documents,
metadatas=metadatas
)
logger.info(f"✅ 批量添加记忆成功: {len(memories)}")
return len(memories)
except Exception as e:
logger.error(f"❌ 批量添加记忆失败: {str(e)}")
return 0
async def search_memories(
self,
user_id: str,
project_id: str,
query: str,
memory_types: Optional[List[str]] = None,
limit: int = 10,
min_importance: float = 0.0,
chapter_range: Optional[tuple] = None
) -> List[Dict[str, Any]]:
"""
语义搜索相关记忆
Args:
user_id: 用户ID
project_id: 项目ID
query: 查询文本(会被转换为向量进行相似度搜索)
memory_types: 过滤特定类型的记忆
limit: 返回结果数量
min_importance: 最低重要性阈值
chapter_range: 章节范围 (start, end)
Returns:
相关记忆列表,按相似度排序
"""
try:
collection = self.get_collection(user_id, project_id)
# 生成查询向量
query_embedding = self.embedding_model.encode(query).tolist()
# 构建过滤条件 - ChromaDB要求使用$and组合多个条件
where_filter = None
conditions = []
if memory_types:
conditions.append({"memory_type": {"$in": memory_types}})
if min_importance > 0:
conditions.append({"importance": {"$gte": min_importance}})
if chapter_range:
conditions.append({"chapter_number": {"$gte": chapter_range[0]}})
conditions.append({"chapter_number": {"$lte": chapter_range[1]}})
# 根据条件数量选择合适的格式
if len(conditions) == 0:
where_filter = None
elif len(conditions) == 1:
where_filter = conditions[0]
else:
where_filter = {"$and": conditions}
# 执行向量相似度搜索
results = collection.query(
query_embeddings=[query_embedding],
n_results=limit,
where=where_filter
)
# 格式化结果
memories = []
if results['ids'] and results['ids'][0]:
for i in range(len(results['ids'][0])):
memories.append({
"id": results['ids'][0][i],
"content": results['documents'][0][i],
"metadata": results['metadatas'][0][i],
"similarity": 1 - results['distances'][0][i] if 'distances' in results else 1.0,
"distance": results['distances'][0][i] if 'distances' in results else 0.0
})
logger.info(f"🔍 语义搜索完成: 查询='{query[:30]}...', 找到{len(memories)}条记忆")
return memories
except Exception as e:
logger.error(f"❌ 搜索记忆失败: {str(e)}")
return []
async def get_recent_memories(
self,
user_id: str,
project_id: str,
current_chapter: int,
recent_count: int = 3,
min_importance: float = 0.5
) -> List[Dict[str, Any]]:
"""
获取最近几章的重要记忆(用于保持连贯性)
Args:
user_id: 用户ID
project_id: 项目ID
current_chapter: 当前章节号
recent_count: 获取最近几章
min_importance: 最低重要性阈值
Returns:
最近章节的记忆列表,按重要性排序
"""
try:
collection = self.get_collection(user_id, project_id)
# 计算章节范围
start_chapter = max(1, current_chapter - recent_count)
# 获取最近章节的记忆
results = collection.get(
where={
"$and": [
{"chapter_number": {"$gte": start_chapter}},
{"chapter_number": {"$lt": current_chapter}},
{"importance": {"$gte": min_importance}}
]
},
limit=100 # 先获取足够多的记忆
)
memories = []
if results['ids']:
for i in range(len(results['ids'])):
memories.append({
"id": results['ids'][i],
"content": results['documents'][i],
"metadata": results['metadatas'][i]
})
# 按重要性和章节号排序
memories.sort(
key=lambda x: (float(x['metadata'].get('importance', 0)),
int(x['metadata'].get('chapter_number', 0))),
reverse=True
)
# 返回最重要的前N条
top_memories = memories[:20]
logger.info(f"📚 获取最近记忆: 章节{start_chapter}-{current_chapter-1}, 找到{len(top_memories)}")
return top_memories
except Exception as e:
logger.error(f"❌ 获取最近记忆失败: {str(e)}")
return []
async def find_unresolved_foreshadows(
self,
user_id: str,
project_id: str,
current_chapter: int
) -> List[Dict[str, Any]]:
"""
查找未完结的伏笔
Args:
user_id: 用户ID
project_id: 项目ID
current_chapter: 当前章节号
Returns:
未完结伏笔列表
"""
try:
collection = self.get_collection(user_id, project_id)
# 查找伏笔状态为1(已埋下但未回收)的记忆
results = collection.get(
where={
"$and": [
{"is_foreshadow": 1},
{"chapter_number": {"$lt": current_chapter}}
]
},
limit=50
)
foreshadows = []
if results['ids']:
for i in range(len(results['ids'])):
foreshadows.append({
"id": results['ids'][i],
"content": results['documents'][i],
"metadata": results['metadatas'][i]
})
# 按重要性排序
foreshadows.sort(
key=lambda x: float(x['metadata'].get('importance', 0)),
reverse=True
)
logger.info(f"🎣 找到未完结伏笔: {len(foreshadows)}")
return foreshadows
except Exception as e:
logger.error(f"❌ 查找伏笔失败: {str(e)}")
return []
async def build_context_for_generation(
self,
user_id: str,
project_id: str,
current_chapter: int,
chapter_outline: str,
character_names: List[str] = None
) -> Dict[str, Any]:
"""
为章节生成构建智能上下文
这是核心功能: 结合多种检索策略,为AI生成提供最相关的记忆
Args:
user_id: 用户ID
project_id: 项目ID
current_chapter: 当前章节号
chapter_outline: 本章大纲
character_names: 涉及的角色名列表
Returns:
包含各种上下文信息的字典
"""
logger.info(f"🧠 开始构建章节{current_chapter}的智能上下文...")
# 1. 获取最近章节上下文(时间连续性)
recent = await self.get_recent_memories(
user_id, project_id, current_chapter,
recent_count=3, min_importance=0.5
)
# 2. 语义搜索相关记忆
relevant = await self.search_memories(
user_id=user_id,
project_id=project_id,
query=chapter_outline,
limit=10,
min_importance=0.4
)
# 3. 查找未完结伏笔
foreshadows = await self.find_unresolved_foreshadows(
user_id, project_id, current_chapter
)
# 4. 如果有指定角色,获取角色相关记忆
character_memories = []
if character_names:
character_query = " ".join(character_names) + " 角色 状态 关系"
character_memories = await self.search_memories(
user_id=user_id,
project_id=project_id,
query=character_query,
memory_types=["character_event", "plot_point"],
limit=8
)
# 5. 获取重要情节点
# 注意:ChromaDB的where条件需要特殊处理,不能同时使用多个顶层条件
try:
plot_points = await self.search_memories(
user_id=user_id,
project_id=project_id,
query="重要 转折 高潮 关键",
memory_types=["plot_point", "hook"],
limit=5,
min_importance=0.7
)
except Exception as e:
logger.error(f"❌ 搜索记忆失败: {str(e)}")
# 降级处理:分别查询
plot_points = []
try:
plot_points = await self.search_memories(
user_id=user_id,
project_id=project_id,
query="重要 转折 高潮 关键",
memory_types=["plot_point", "hook"],
limit=5
)
except Exception as e2:
logger.warning(f"⚠️ 降级查询也失败: {str(e2)}")
plot_points = []
context = {
"recent_context": self._format_memories(recent, "最近章节记忆"),
"relevant_memories": self._format_memories(relevant, "语义相关记忆"),
"character_states": self._format_memories(character_memories, "角色相关记忆"),
"foreshadows": self._format_memories(foreshadows[:5], "未完结伏笔"),
"plot_points": self._format_memories(plot_points, "重要情节点"),
"stats": {
"recent_count": len(recent),
"relevant_count": len(relevant),
"character_count": len(character_memories),
"foreshadow_count": len(foreshadows),
"plot_point_count": len(plot_points)
}
}
logger.info(f"✅ 上下文构建完成: 最近{len(recent)}条, 相关{len(relevant)}条, 伏笔{len(foreshadows)}")
return context
def _format_memories(self, memories: List[Dict], section_title: str = "记忆") -> str:
"""
格式化记忆列表为文本
Args:
memories: 记忆列表
section_title: 章节标题
Returns:
格式化后的文本
"""
if not memories:
return f"{section_title}\n暂无相关记忆\n"
lines = [f"{section_title}"]
for i, mem in enumerate(memories, 1):
meta = mem.get('metadata', {})
chapter_num = meta.get('chapter_number', '?')
mem_type = meta.get('memory_type', '未知')
importance = float(meta.get('importance', 0.5))
title = meta.get('title', '')
content = mem['content']
# 格式: [序号] 第X章-类型(重要性) 标题: 内容
line = f"{i}. [第{chapter_num}章-{mem_type}{importance:.1f}]"
if title:
line += f" {title}: {content[:100]}"
else:
line += f" {content[:150]}"
lines.append(line)
return "\n".join(lines) + "\n"
async def delete_chapter_memories(
self,
user_id: str,
project_id: str,
chapter_id: str
) -> bool:
"""
删除指定章节的所有记忆
Args:
user_id: 用户ID
project_id: 项目ID
chapter_id: 章节ID
Returns:
是否删除成功
"""
try:
collection = self.get_collection(user_id, project_id)
# 查找该章节的所有记忆
results = collection.get(
where={"chapter_id": chapter_id}
)
if results['ids']:
# 删除这些记忆
collection.delete(ids=results['ids'])
logger.info(f"🗑️ 已删除章节{chapter_id[:8]}{len(results['ids'])}条记忆")
return True
else:
logger.info(f"️ 章节{chapter_id[:8]}没有记忆需要删除")
return True
except Exception as e:
logger.error(f"❌ 删除章节记忆失败: {str(e)}")
return False
async def update_memory(
self,
user_id: str,
project_id: str,
memory_id: str,
content: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None
) -> bool:
"""
更新记忆内容或元数据
Args:
user_id: 用户ID
project_id: 项目ID
memory_id: 记忆ID
content: 新内容(可选)
metadata: 新元数据(可选)
Returns:
是否更新成功
"""
try:
collection = self.get_collection(user_id, project_id)
update_data = {}
if content:
# 重新生成embedding
embedding = self.embedding_model.encode(content).tolist()
update_data['embeddings'] = [embedding]
update_data['documents'] = [content]
if metadata:
# 准备新的元数据
chroma_metadata = {}
for key, value in metadata.items():
if isinstance(value, (list, dict)):
chroma_metadata[key] = json.dumps(value, ensure_ascii=False)
else:
chroma_metadata[key] = value
update_data['metadatas'] = [chroma_metadata]
if update_data:
collection.update(
ids=[memory_id],
**update_data
)
logger.info(f"✅ 记忆已更新: {memory_id[:8]}...")
return True
else:
logger.warning("⚠️ 没有提供更新内容")
return False
except Exception as e:
logger.error(f"❌ 更新记忆失败: {str(e)}")
return False
async def get_memory_stats(
self,
user_id: str,
project_id: str
) -> Dict[str, Any]:
"""
获取记忆统计信息
Args:
user_id: 用户ID
project_id: 项目ID
Returns:
统计信息字典
"""
try:
collection = self.get_collection(user_id, project_id)
# 获取所有记忆
all_memories = collection.get()
if not all_memories['ids']:
return {
"total_count": 0,
"by_type": {},
"by_chapter": {},
"foreshadow_count": 0
}
# 统计各类型数量
type_counts = {}
chapter_counts = {}
foreshadow_count = 0
for i, meta in enumerate(all_memories['metadatas']):
mem_type = meta.get('memory_type', 'unknown')
chapter_num = meta.get('chapter_number', 0)
is_foreshadow = meta.get('is_foreshadow', 0)
type_counts[mem_type] = type_counts.get(mem_type, 0) + 1
chapter_counts[str(chapter_num)] = chapter_counts.get(str(chapter_num), 0) + 1
if is_foreshadow == 1:
foreshadow_count += 1
stats = {
"total_count": len(all_memories['ids']),
"by_type": type_counts,
"by_chapter": chapter_counts,
"foreshadow_count": foreshadow_count,
"foreshadow_resolved": sum(1 for m in all_memories['metadatas'] if m.get('is_foreshadow') == 2)
}
logger.info(f"📊 记忆统计: 总计{stats['total_count']}条, 伏笔{foreshadow_count}")
return stats
except Exception as e:
logger.error(f"❌ 获取统计信息失败: {str(e)}")
return {"error": str(e)}
# 创建全局实例
memory_service = MemoryService()
+559
View File
@@ -0,0 +1,559 @@
"""剧情分析服务 - 自动分析章节的钩子、伏笔、冲突等元素"""
from typing import Dict, Any, List, Optional
from app.services.ai_service import AIService
from app.logger import get_logger
import json
import re
logger = get_logger(__name__)
class PlotAnalyzer:
"""剧情分析器 - 使用AI分析章节内容"""
# AI分析提示词模板
ANALYSIS_PROMPT = """你是一位专业的小说编辑和剧情分析师。请深度分析以下章节内容:
**章节信息:**
- 章节: 第{chapter_number}
- 标题: {title}
- 字数: {word_count}
**章节内容:**
{content}
---
**分析任务:**
请从专业编辑的角度,全面分析这一章节:
### 1. 剧情钩子 (Hooks) - 吸引读者的元素
识别能够吸引读者继续阅读的关键元素:
- **悬念钩子**: 未解之谜、疑问、谜团
- **情感钩子**: 引发共鸣的情感点、触动心弦的时刻
- **冲突钩子**: 矛盾对抗、紧张局势
- **认知钩子**: 颠覆认知的信息、惊人真相
每个钩子需要:
- 类型分类
- 具体内容描述
- 强度评分(1-10)
- 出现位置(开头/中段/结尾)
- **关键词**: 【必填】从章节原文中逐字复制一段关键文本(8-25字),必须是原文中真实存在的连续文字,用于在文本中精确定位。不要概括或改写,必须原样复制!
### 2. 伏笔分析 (Foreshadowing)
- **埋下的新伏笔**: 描述内容、预期作用、隐藏程度(1-10)
- **回收的旧伏笔**: 呼应哪一章、回收效果评分
- **伏笔质量**: 巧妙性和合理性评估
- **关键词**: 【必填】从章节原文中逐字复制一段关键文本(8-25字),必须是原文中真实存在的连续文字,用于在文本中精确定位。不要概括或改写,必须原样复制!
### 3. 冲突分析 (Conflict)
- 冲突类型: 人与人/人与己/人与环境/人与社会
- 冲突各方及其立场
- 冲突强度评分(1-10)
- 冲突解决进度(0-100%)
### 4. 情感曲线 (Emotional Arc)
- 主导情绪: 紧张/温馨/悲伤/激昂/平静等
- 情感强度(1-10)
- 情绪变化轨迹描述
### 5. 角色状态追踪 (Character Development)
对每个出场角色分析:
- 心理状态变化(前→后)
- 关系变化
- 关键行动和决策
- 成长或退步
### 6. 关键情节点 (Plot Points)
列出3-5个核心情节点:
- 情节内容
- 类型(revelation/conflict/resolution/transition)
- 重要性(0.0-1.0)
- 对故事的影响
- **关键词**: 【必填】从章节原文中逐字复制一段关键文本(8-25字),必须是原文中真实存在的连续文字,用于在文本中精确定位。不要概括或改写,必须原样复制!
### 7. 场景与节奏
- 主要场景
- 叙事节奏(快/中/慢)
- 对话与描写的比例
### 8. 质量评分
- 节奏把控: 1-10分
- 吸引力: 1-10分
- 连贯性: 1-10分
- 整体质量: 1-10分
### 9. 改进建议
提供3-5条具体的改进建议
---
**输出格式(纯JSON,不要markdown标记):**
{{
"hooks": [
{{
"type": "悬念",
"content": "具体描述",
"strength": 8,
"position": "中段",
"keyword": "必须从原文逐字复制的文本片段"
}}
],
"foreshadows": [
{{
"content": "伏笔内容",
"type": "planted",
"strength": 7,
"subtlety": 8,
"reference_chapter": null,
"keyword": "必须从原文逐字复制的文本片段"
}}
],
"conflict": {{
"types": ["人与人", "人与己"],
"parties": ["主角-复仇", "反派-维护现状"],
"level": 8,
"description": "冲突描述",
"resolution_progress": 0.3
}},
"emotional_arc": {{
"primary_emotion": "紧张",
"intensity": 8,
"curve": "平静→紧张→高潮→释放",
"secondary_emotions": ["期待", "焦虑"]
}},
"character_states": [
{{
"character_name": "张三",
"state_before": "犹豫",
"state_after": "坚定",
"psychological_change": "心理变化描述",
"key_event": "触发事件",
"relationship_changes": {{"李四": "关系改善"}}
}}
],
"plot_points": [
{{
"content": "情节点描述",
"type": "revelation",
"importance": 0.9,
"impact": "推动故事发展",
"keyword": "必须从原文逐字复制的文本片段"
}}
],
"scenes": [
{{
"location": "地点",
"atmosphere": "氛围",
"duration": "时长估计"
}}
],
"pacing": "varied",
"dialogue_ratio": 0.4,
"description_ratio": 0.3,
"scores": {{
"pacing": 8,
"engagement": 9,
"coherence": 8,
"overall": 8.5
}},
"plot_stage": "发展",
"suggestions": [
"具体建议1",
"具体建议2"
]
}}
**重要提示:**
1. 每个钩子、伏笔、情节点的keyword字段是必填的,不能为空
2. keyword必须是从章节原文中逐字复制的文本,长度8-25字
3. keyword用于在前端标注文本位置,所以必须能在原文中精确找到
4. 不要使用概括性语句或改写后的文字作为keyword
只返回JSON,不要其他说明。"""
def __init__(self, ai_service: AIService):
"""
初始化剧情分析器
Args:
ai_service: AI服务实例
"""
self.ai_service = ai_service
logger.info("✅ PlotAnalyzer初始化成功")
async def analyze_chapter(
self,
chapter_number: int,
title: str,
content: str,
word_count: int
) -> Optional[Dict[str, Any]]:
"""
分析单章内容
Args:
chapter_number: 章节号
title: 章节标题
content: 章节内容
word_count: 字数
Returns:
分析结果字典,失败返回None
"""
try:
logger.info(f"🔍 开始分析第{chapter_number}章: {title}")
# 如果内容过长,截取前8000字(避免超token)
analysis_content = content[:8000] if len(content) > 8000 else content
# 构建提示词
prompt = self.ANALYSIS_PROMPT.format(
chapter_number=chapter_number,
title=title,
word_count=word_count,
content=analysis_content
)
# 调用AI进行分析
# 注意:不指定max_tokens,使用用户在设置中配置的值
logger.info(f" 调用AI分析(内容长度: {len(analysis_content)}字)...")
response = await self.ai_service.generate_text(
prompt=prompt,
temperature=0.3 # 降低温度以获得更稳定的JSON输出
)
# 解析JSON结果
analysis_result = self._parse_analysis_response(response)
if analysis_result:
logger.info(f"✅ 第{chapter_number}章分析完成")
logger.info(f" - 钩子: {len(analysis_result.get('hooks', []))}")
logger.info(f" - 伏笔: {len(analysis_result.get('foreshadows', []))}")
logger.info(f" - 情节点: {len(analysis_result.get('plot_points', []))}")
logger.info(f" - 整体评分: {analysis_result.get('scores', {}).get('overall', 'N/A')}")
return analysis_result
else:
logger.error(f"❌ 第{chapter_number}章分析失败: JSON解析错误")
return None
except Exception as e:
logger.error(f"❌ 章节分析异常: {str(e)}")
return None
def _parse_analysis_response(self, response: str) -> Optional[Dict[str, Any]]:
"""
解析AI返回的分析结果
Args:
response: AI返回的文本
Returns:
解析后的字典,失败返回None
"""
try:
# 清理响应文本
cleaned = response.strip()
# 移除可能的markdown标记
cleaned = re.sub(r'^```json\s*', '', cleaned)
cleaned = re.sub(r'^```\s*', '', cleaned)
cleaned = re.sub(r'\s*```$', '', cleaned)
# 尝试解析JSON
result = json.loads(cleaned)
# 验证必要字段
required_fields = ['hooks', 'plot_points', 'scores']
for field in required_fields:
if field not in result:
logger.warning(f"⚠️ 分析结果缺少字段: {field}")
result[field] = [] if field != 'scores' else {}
return result
except json.JSONDecodeError as e:
logger.error(f"❌ JSON解析失败: {str(e)}")
logger.error(f" 原始响应(前500字): {response[:500]}")
# 尝试提取JSON部分
json_match = re.search(r'\{[\s\S]*\}', response)
if json_match:
try:
result = json.loads(json_match.group())
logger.info("✅ 通过正则提取成功解析JSON")
return result
except:
pass
return None
except Exception as e:
logger.error(f"❌ 解析异常: {str(e)}")
return None
def extract_memories_from_analysis(
self,
analysis: Dict[str, Any],
chapter_id: str,
chapter_number: int,
chapter_content: str = ""
) -> List[Dict[str, Any]]:
"""
从分析结果中提取记忆片段
Args:
analysis: 分析结果
chapter_id: 章节ID
chapter_number: 章节号
chapter_content: 章节完整内容(用于计算位置)
Returns:
记忆片段列表
"""
memories = []
try:
# 1. 提取钩子作为记忆
for i, hook in enumerate(analysis.get('hooks', [])):
if hook.get('strength', 0) >= 6: # 只保存强度>=6的钩子
keyword = hook.get('keyword', '')
position, length = self._find_text_position(chapter_content, keyword)
logger.info(f" 钩子位置: keyword='{keyword[:30]}...', pos={position}, len={length}")
memories.append({
'type': 'hook',
'content': f"[{hook.get('type', '未知')}钩子] {hook.get('content', '')}",
'title': f"{hook.get('type', '钩子')} - {hook.get('position', '')}",
'metadata': {
'chapter_id': chapter_id,
'chapter_number': chapter_number,
'importance_score': min(hook.get('strength', 5) / 10, 1.0),
'tags': [hook.get('type', '钩子'), hook.get('position', '')],
'is_foreshadow': 0,
'keyword': keyword,
'text_position': position,
'text_length': length,
'strength': hook.get('strength', 5),
'position_desc': hook.get('position', '')
}
})
# 2. 提取伏笔作为记忆
for i, foreshadow in enumerate(analysis.get('foreshadows', [])):
is_planted = foreshadow.get('type') == 'planted'
keyword = foreshadow.get('keyword', '')
position, length = self._find_text_position(chapter_content, keyword)
logger.info(f" 伏笔位置: keyword='{keyword[:30]}...', pos={position}, len={length}")
memories.append({
'type': 'foreshadow',
'content': foreshadow.get('content', ''),
'title': f"{'埋下伏笔' if is_planted else '回收伏笔'}",
'metadata': {
'chapter_id': chapter_id,
'chapter_number': chapter_number,
'importance_score': min(foreshadow.get('strength', 5) / 10, 1.0),
'tags': ['伏笔', foreshadow.get('type', 'planted')],
'is_foreshadow': 1 if is_planted else 2,
'reference_chapter': foreshadow.get('reference_chapter'),
'keyword': keyword,
'text_position': position,
'text_length': length,
'foreshadow_type': foreshadow.get('type', 'planted'),
'strength': foreshadow.get('strength', 5)
}
})
# 3. 提取关键情节点
for i, plot_point in enumerate(analysis.get('plot_points', [])):
if plot_point.get('importance', 0) >= 0.6: # 只保存重要性>=0.6的情节点
keyword = plot_point.get('keyword', '')
position, length = self._find_text_position(chapter_content, keyword)
logger.info(f" 情节点位置: keyword='{keyword[:30]}...', pos={position}, len={length}")
memories.append({
'type': 'plot_point',
'content': f"{plot_point.get('content', '')}。影响: {plot_point.get('impact', '')}",
'title': f"情节点 - {plot_point.get('type', '未知')}",
'metadata': {
'chapter_id': chapter_id,
'chapter_number': chapter_number,
'importance_score': plot_point.get('importance', 0.5),
'tags': ['情节点', plot_point.get('type', '未知')],
'is_foreshadow': 0,
'keyword': keyword,
'text_position': position,
'text_length': length
}
})
# 4. 提取角色状态变化
for i, char_state in enumerate(analysis.get('character_states', [])):
char_name = char_state.get('character_name', '未知角色')
memories.append({
'type': 'character_event',
'content': f"{char_name}的状态变化: {char_state.get('state_before', '')}{char_state.get('state_after', '')}{char_state.get('psychological_change', '')}",
'title': f"{char_name}的变化",
'metadata': {
'chapter_id': chapter_id,
'chapter_number': chapter_number,
'importance_score': 0.7,
'tags': ['角色', char_name, '状态变化'],
'related_characters': [char_name],
'is_foreshadow': 0
}
})
# 5. 如果有重要冲突,也记录下来
conflict = analysis.get('conflict', {})
if conflict and conflict.get('level', 0) >= 7:
# 确保 parties 和 types 都是字符串列表
parties = conflict.get('parties', [])
if parties and isinstance(parties, list):
parties = [str(p) for p in parties]
types = conflict.get('types', [])
if types and isinstance(types, list):
types = [str(t) for t in types]
memories.append({
'type': 'plot_point',
'content': f"重要冲突: {conflict.get('description', '')}。冲突各方: {', '.join(parties)}",
'title': f"冲突 - 强度{conflict.get('level', 0)}",
'metadata': {
'chapter_id': chapter_id,
'chapter_number': chapter_number,
'importance_score': min(conflict.get('level', 5) / 10, 1.0),
'tags': ['冲突'] + types,
'is_foreshadow': 0
}
})
logger.info(f"📝 从分析中提取了{len(memories)}条记忆")
return memories
except Exception as e:
logger.error(f"❌ 提取记忆失败: {str(e)}")
return []
def _find_text_position(self, full_text: str, keyword: str) -> tuple[int, int]:
"""
在全文中查找关键词位置
Args:
full_text: 完整文本
keyword: 关键词
Returns:
(起始位置, 长度) 如果未找到返回(-1, 0)
"""
if not keyword or not full_text:
return (-1, 0)
try:
# 1. 精确匹配
pos = full_text.find(keyword)
if pos != -1:
return (pos, len(keyword))
# 2. 去除标点符号后匹配
import re
clean_keyword = re.sub(r'[,。!?、;:""''()《》【】]', '', keyword)
clean_text = re.sub(r'[,。!?、;:""''()《》【】]', '', full_text)
pos = clean_text.find(clean_keyword)
if pos != -1:
# 反向映射到原文位置(简化处理)
return (pos, len(clean_keyword))
# 3. 模糊匹配:查找关键词的前半部分
if len(keyword) > 10:
partial = keyword[:min(15, len(keyword))]
pos = full_text.find(partial)
if pos != -1:
return (pos, len(partial))
# 4. 未找到
logger.debug(f"未找到关键词位置: {keyword[:30]}...")
return (-1, 0)
except Exception as e:
logger.error(f"查找位置失败: {str(e)}")
return (-1, 0)
def generate_analysis_summary(self, analysis: Dict[str, Any]) -> str:
"""
生成分析摘要文本
Args:
analysis: 分析结果
Returns:
格式化的摘要文本
"""
try:
lines = ["=== 章节分析报告 ===\n"]
# 整体评分
scores = analysis.get('scores', {})
lines.append(f"【整体评分】")
lines.append(f" 整体质量: {scores.get('overall', 'N/A')}/10")
lines.append(f" 节奏把控: {scores.get('pacing', 'N/A')}/10")
lines.append(f" 吸引力: {scores.get('engagement', 'N/A')}/10")
lines.append(f" 连贯性: {scores.get('coherence', 'N/A')}/10\n")
# 剧情阶段
lines.append(f"【剧情阶段】{analysis.get('plot_stage', '未知')}\n")
# 钩子统计
hooks = analysis.get('hooks', [])
if hooks:
lines.append(f"【钩子分析】共{len(hooks)}")
for hook in hooks[:3]: # 只显示前3个
lines.append(f" • [{hook.get('type')}] {hook.get('content', '')[:50]}... (强度:{hook.get('strength', 0)})")
lines.append("")
# 伏笔统计
foreshadows = analysis.get('foreshadows', [])
if foreshadows:
planted = sum(1 for f in foreshadows if f.get('type') == 'planted')
resolved = sum(1 for f in foreshadows if f.get('type') == 'resolved')
lines.append(f"【伏笔分析】埋下{planted}个, 回收{resolved}\n")
# 冲突分析
conflict = analysis.get('conflict', {})
if conflict:
lines.append(f"【冲突分析】")
lines.append(f" 类型: {', '.join(conflict.get('types', []))}")
lines.append(f" 强度: {conflict.get('level', 0)}/10")
lines.append(f" 进度: {int(conflict.get('resolution_progress', 0) * 100)}%\n")
# 改进建议
suggestions = analysis.get('suggestions', [])
if suggestions:
lines.append(f"【改进建议】")
for i, sug in enumerate(suggestions, 1):
lines.append(f" {i}. {sug}")
return "\n".join(lines)
except Exception as e:
logger.error(f"❌ 生成摘要失败: {str(e)}")
return "分析摘要生成失败"
# 创建全局实例(需要时手动初始化)
_plot_analyzer_instance = None
def get_plot_analyzer(ai_service: AIService) -> PlotAnalyzer:
"""获取剧情分析器实例"""
global _plot_analyzer_instance
if _plot_analyzer_instance is None:
_plot_analyzer_instance = PlotAnalyzer(ai_service)
return _plot_analyzer_instance
+81 -15
View File
@@ -315,7 +315,7 @@ class PromptService:
2. 数组中要包含{chapter_count}个章节对象
3. 文本中不要使用中文引号(""),改用【】或《》"""
# 大纲续写提示词
# 大纲续写提示词(记忆增强版)
OUTLINE_CONTINUE_GENERATION = """你是一位经验丰富的小说作家和编剧。请基于以下信息续写小说大纲:
【项目信息】
@@ -340,6 +340,11 @@ class PromptService:
【最近剧情】
{recent_plot}
【🧠 智能记忆系统 - 续写参考】
以下是从故事记忆库中检索到的相关信息,请在续写大纲时参考:
{memory_context}
【续写指导】
- 当前情节阶段:{plot_stage_instruction}
- 起始章节编号:第{start_chapter}
@@ -348,10 +353,12 @@ class PromptService:
请生成第{start_chapter}章到第{end_chapter}章的大纲。
要求:
- 与前文自然衔接,保持故事连贯性
- 遵循情节阶段的发展要求
- 保持与已有章节相同的风格和详细程度
- 推进角色成长和情节发展
- **剧情连贯性**与前文自然衔接,保持故事连贯性
- **记忆参考**:适当参考记忆系统中的伏笔、钩子和情节点
- **伏笔回收**:可以考虑回收未完结的伏笔,制造呼应
- **角色发展**:遵循角色在前文中的成长轨迹
- **情节阶段**:遵循情节阶段的发展要求
- **风格一致**:保持与已有章节相同的风格和详细程度
**重要格式要求:**
1. 只返回纯JSON数组格式,不要包含任何markdown标记、代码块标记或其他说明文字
@@ -465,7 +472,7 @@ class PromptService:
请直接输出章节正文内容,不要包含章节标题和其他说明文字。"""
# 章节完整创作提示词(带前置章节上下文)
# 章节完整创作提示词(带前置章节上下文和记忆增强
CHAPTER_GENERATION_WITH_CONTEXT = """你是一位专业的小说作家。请根据以下信息创作本章内容:
项目信息:
@@ -489,6 +496,11 @@ class PromptService:
【已完成的前置章节内容】
{previous_content}
【🧠 智能记忆系统 - 重要参考】
以下是从故事记忆库中检索到的相关信息,请在创作时适当参考和呼应:
{memory_context}
本章信息:
- 章节序号:第{chapter_number}
- 章节标题:{chapter_title}
@@ -518,8 +530,15 @@ class PromptService:
- 体现世界观特色
5. **承上启下**
- 开头自然衔接上一章结尾
- 结尾为下一章做好铺垫
- 开头自然衔接上一章结尾
- 结尾为下一章做好铺垫
6. **记忆系统使用指南**
- **最近章节记忆**:保持情节连贯,注意角色状态和剧情发展
- **语义相关记忆**:参考相似情节的处理方式
- **未完结伏笔**:适当时机可以回收伏笔,制造呼应效果
- **角色状态记忆**:确保角色行为符合其发展轨迹
- **重要情节点**:与关键剧情保持一致
请直接输出章节正文内容,不要包含章节标题和其他说明文字。"""
@@ -746,14 +765,26 @@ class PromptService:
characters_info: str, outlines_context: str,
chapter_number: int, chapter_title: str,
chapter_outline: str, style_content: str = "",
target_word_count: int = 3000) -> str:
target_word_count: int = 3000,
memory_context: dict = None) -> str:
"""
获取章节完整创作提示词
Args:
style_content: 写作风格要求内容,如果提供则会追加到提示词中
target_word_count: 目标字数,默认3000字
memory_context: 记忆上下文(可选)
"""
# 格式化记忆上下文
memory_text = ""
if memory_context:
memory_text = "\n【🧠 智能记忆系统 - 重要参考】\n"
memory_text += memory_context.get('recent_context', '')
memory_text += "\n" + memory_context.get('relevant_memories', '')
memory_text += "\n" + memory_context.get('foreshadows', '')
memory_text += "\n" + memory_context.get('character_states', '')
memory_text += "\n" + memory_context.get('plot_points', '')
base_prompt = cls.format_prompt(
cls.CHAPTER_GENERATION,
title=title,
@@ -772,6 +803,13 @@ class PromptService:
target_word_count=target_word_count
)
# 插入记忆上下文
if memory_text:
base_prompt = base_prompt.replace(
"本章信息:",
memory_text + "\n\n本章信息:"
)
# 如果有风格要求,应用到提示词中
if style_content:
return WritingStyleManager.apply_style_to_prompt(base_prompt, style_content)
@@ -786,14 +824,27 @@ class PromptService:
previous_content: str, chapter_number: int,
chapter_title: str, chapter_outline: str,
style_content: str = "",
target_word_count: int = 3000) -> str:
target_word_count: int = 3000,
memory_context: dict = None) -> str:
"""
获取章节完整创作提示词(带前置章节上下文)
获取章节完整创作提示词(带前置章节上下文和记忆增强
Args:
style_content: 写作风格要求内容,如果提供则会追加到提示词中
target_word_count: 目标字数,默认3000字
memory_context: 记忆上下文(可选)
"""
# 格式化记忆上下文
memory_text = ""
if memory_context:
memory_text = memory_context.get('recent_context', '')
memory_text += "\n" + memory_context.get('relevant_memories', '')
memory_text += "\n" + memory_context.get('foreshadows', '')
memory_text += "\n" + memory_context.get('character_states', '')
memory_text += "\n" + memory_context.get('plot_points', '')
else:
memory_text = "暂无相关记忆"
base_prompt = cls.format_prompt(
cls.CHAPTER_GENERATION_WITH_CONTEXT,
title=title,
@@ -810,7 +861,8 @@ class PromptService:
chapter_number=chapter_number,
chapter_title=chapter_title,
chapter_outline=chapter_outline,
target_word_count=target_word_count
target_word_count=target_word_count,
memory_context=memory_text
)
# 如果有风格要求,应用到提示词中
@@ -839,9 +891,22 @@ class PromptService:
current_chapter_count: int, all_chapters_brief: str,
recent_plot: str, plot_stage_instruction: str,
start_chapter: int, story_direction: str,
requirements: str = "") -> str:
"""获取大纲续写提示词"""
requirements: str = "",
memory_context: dict = None) -> str:
"""获取大纲续写提示词(支持记忆增强)"""
end_chapter = start_chapter + chapter_count - 1
# 格式化记忆上下文
memory_text = ""
if memory_context:
memory_text = memory_context.get('recent_context', '')
memory_text += "\n" + memory_context.get('relevant_memories', '')
memory_text += "\n" + memory_context.get('foreshadows', '')
memory_text += "\n" + memory_context.get('character_states', '')
memory_text += "\n" + memory_context.get('plot_points', '')
else:
memory_text = "暂无相关记忆(可能是首次续写或记忆库为空)"
return cls.format_prompt(
cls.OUTLINE_CONTINUE_GENERATION,
title=title,
@@ -861,7 +926,8 @@ class PromptService:
start_chapter=start_chapter,
end_chapter=end_chapter,
story_direction=story_direction,
requirements=requirements or "无特殊要求"
requirements=requirements or "无特殊要求",
memory_context=memory_text
)
@classmethod
+5 -1
View File
@@ -17,4 +17,8 @@ anthropic==0.18.0
# 工具库
httpx==0.26.0
python-dotenv==1.0.0
python-dotenv==1.0.0
# 向量数据库和Embedding (长期记忆系统)
chromadb>=0.5.0
sentence-transformers>=2.3.1