update:1.更新导入导出功能 2.实现RAG记忆功能,引入剧情分析功能
This commit is contained in:
+712
-6
@@ -1,11 +1,12 @@
|
||||
"""章节管理API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Query
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Query, BackgroundTasks
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.chapter import Chapter
|
||||
@@ -14,6 +15,8 @@ from app.models.outline import Outline
|
||||
from app.models.character import Character
|
||||
from app.models.generation_history import GenerationHistory
|
||||
from app.models.writing_style import WritingStyle
|
||||
from app.models.analysis_task import AnalysisTask
|
||||
from app.models.memory import PlotAnalysis, StoryMemory
|
||||
from app.schemas.chapter import (
|
||||
ChapterCreate,
|
||||
ChapterUpdate,
|
||||
@@ -23,6 +26,8 @@ from app.schemas.chapter import (
|
||||
)
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.prompt_service import prompt_service
|
||||
from app.services.plot_analyzer import PlotAnalyzer
|
||||
from app.services.memory_service import memory_service
|
||||
from app.logger import get_logger
|
||||
from app.api.settings import get_user_ai_service
|
||||
|
||||
@@ -101,6 +106,63 @@ async def get_chapter(
|
||||
return chapter
|
||||
|
||||
|
||||
@router.get("/{chapter_id}/navigation", summary="获取章节导航信息")
|
||||
async def get_chapter_navigation(
|
||||
chapter_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取章节的导航信息(上一章/下一章)
|
||||
用于章节阅读器的翻页功能
|
||||
"""
|
||||
# 获取当前章节
|
||||
result = await db.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
current_chapter = result.scalar_one_or_none()
|
||||
|
||||
if not current_chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
# 获取上一章
|
||||
prev_result = await db.execute(
|
||||
select(Chapter)
|
||||
.where(Chapter.project_id == current_chapter.project_id)
|
||||
.where(Chapter.chapter_number < current_chapter.chapter_number)
|
||||
.order_by(Chapter.chapter_number.desc())
|
||||
.limit(1)
|
||||
)
|
||||
prev_chapter = prev_result.scalar_one_or_none()
|
||||
|
||||
# 获取下一章
|
||||
next_result = await db.execute(
|
||||
select(Chapter)
|
||||
.where(Chapter.project_id == current_chapter.project_id)
|
||||
.where(Chapter.chapter_number > current_chapter.chapter_number)
|
||||
.order_by(Chapter.chapter_number.asc())
|
||||
.limit(1)
|
||||
)
|
||||
next_chapter = next_result.scalar_one_or_none()
|
||||
|
||||
return {
|
||||
"current": {
|
||||
"id": current_chapter.id,
|
||||
"chapter_number": current_chapter.chapter_number,
|
||||
"title": current_chapter.title
|
||||
},
|
||||
"previous": {
|
||||
"id": prev_chapter.id,
|
||||
"chapter_number": prev_chapter.chapter_number,
|
||||
"title": prev_chapter.title
|
||||
} if prev_chapter else None,
|
||||
"next": {
|
||||
"id": next_chapter.id,
|
||||
"chapter_number": next_chapter.chapter_number,
|
||||
"title": next_chapter.title
|
||||
} if next_chapter else None
|
||||
}
|
||||
|
||||
|
||||
@router.put("/{chapter_id}", response_model=ChapterResponse, summary="更新章节")
|
||||
async def update_chapter(
|
||||
chapter_id: str,
|
||||
@@ -248,10 +310,273 @@ async def check_can_generate(
|
||||
}
|
||||
|
||||
|
||||
async def analyze_chapter_background(
|
||||
chapter_id: str,
|
||||
user_id: str,
|
||||
project_id: str,
|
||||
task_id: str,
|
||||
ai_service: AIService
|
||||
):
|
||||
"""
|
||||
后台异步分析章节
|
||||
|
||||
Args:
|
||||
chapter_id: 章节ID
|
||||
user_id: 用户ID
|
||||
project_id: 项目ID
|
||||
task_id: 任务ID
|
||||
ai_service: AI服务实例
|
||||
"""
|
||||
db_session = None
|
||||
try:
|
||||
logger.info(f"🔍 开始后台分析章节: {chapter_id}")
|
||||
|
||||
# 等待一小段时间,确保主会话的commit已经持久化到磁盘
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# 创建独立数据库会话
|
||||
from app.database import get_engine
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession
|
||||
|
||||
engine = await get_engine(user_id)
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False
|
||||
)
|
||||
db_session = AsyncSessionLocal()
|
||||
|
||||
# 1. 获取任务(添加重试逻辑)
|
||||
task = None
|
||||
for retry in range(3):
|
||||
task_result = await db_session.execute(
|
||||
select(AnalysisTask).where(AnalysisTask.id == task_id)
|
||||
)
|
||||
task = task_result.scalar_one_or_none()
|
||||
if task:
|
||||
break
|
||||
if retry < 2:
|
||||
logger.warning(f"⚠️ 第{retry+1}次未找到任务 {task_id},等待后重试...")
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
if not task:
|
||||
logger.error(f"❌ 任务不存在: {task_id}")
|
||||
return
|
||||
|
||||
task.status = 'running'
|
||||
task.started_at = datetime.now()
|
||||
task.progress = 10
|
||||
await db_session.commit()
|
||||
|
||||
# 2. 获取章节信息
|
||||
chapter_result = await db_session.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
chapter = chapter_result.scalar_one_or_none()
|
||||
if not chapter or not chapter.content:
|
||||
task.status = 'failed'
|
||||
task.error_message = '章节不存在或内容为空'
|
||||
task.completed_at = datetime.now()
|
||||
await db_session.commit()
|
||||
logger.error(f"❌ 章节不存在或内容为空: {chapter_id}")
|
||||
return
|
||||
|
||||
task.progress = 20
|
||||
await db_session.commit()
|
||||
|
||||
# 3. 使用PlotAnalyzer分析章节
|
||||
analyzer = PlotAnalyzer(ai_service)
|
||||
analysis_result = await analyzer.analyze_chapter(
|
||||
chapter_number=chapter.chapter_number,
|
||||
title=chapter.title,
|
||||
content=chapter.content,
|
||||
word_count=chapter.word_count or len(chapter.content)
|
||||
)
|
||||
|
||||
if not analysis_result:
|
||||
task.status = 'failed'
|
||||
task.error_message = 'AI分析失败,请检查日志'
|
||||
task.completed_at = datetime.now()
|
||||
await db_session.commit()
|
||||
logger.error(f"❌ AI分析失败: {chapter_id}")
|
||||
return
|
||||
|
||||
task.progress = 60
|
||||
await db_session.commit()
|
||||
|
||||
# 4. 保存分析结果到数据库(先检查是否已存在)
|
||||
existing_analysis_result = await db_session.execute(
|
||||
select(PlotAnalysis).where(PlotAnalysis.chapter_id == chapter_id)
|
||||
)
|
||||
existing_analysis = existing_analysis_result.scalar_one_or_none()
|
||||
|
||||
if existing_analysis:
|
||||
# 更新现有记录
|
||||
logger.info(f" 更新现有分析记录: {existing_analysis.id}")
|
||||
existing_analysis.plot_stage = analysis_result.get('plot_stage', '发展')
|
||||
existing_analysis.conflict_level = analysis_result.get('conflict', {}).get('level', 0)
|
||||
existing_analysis.conflict_types = analysis_result.get('conflict', {}).get('types', [])
|
||||
existing_analysis.emotional_tone = analysis_result.get('emotional_arc', {}).get('primary_emotion', '')
|
||||
existing_analysis.emotional_intensity = analysis_result.get('emotional_arc', {}).get('intensity', 0) / 10.0
|
||||
existing_analysis.hooks = analysis_result.get('hooks', [])
|
||||
existing_analysis.hooks_count = len(analysis_result.get('hooks', []))
|
||||
existing_analysis.foreshadows = analysis_result.get('foreshadows', [])
|
||||
existing_analysis.foreshadows_planted = sum(1 for f in analysis_result.get('foreshadows', []) if f.get('type') == 'planted')
|
||||
existing_analysis.foreshadows_resolved = sum(1 for f in analysis_result.get('foreshadows', []) if f.get('type') == 'resolved')
|
||||
existing_analysis.plot_points = analysis_result.get('plot_points', [])
|
||||
existing_analysis.plot_points_count = len(analysis_result.get('plot_points', []))
|
||||
existing_analysis.character_states = analysis_result.get('character_states', [])
|
||||
existing_analysis.scenes = analysis_result.get('scenes', [])
|
||||
existing_analysis.pacing = analysis_result.get('pacing', 'moderate')
|
||||
existing_analysis.overall_quality_score = analysis_result.get('scores', {}).get('overall', 0)
|
||||
existing_analysis.pacing_score = analysis_result.get('scores', {}).get('pacing', 0)
|
||||
existing_analysis.engagement_score = analysis_result.get('scores', {}).get('engagement', 0)
|
||||
existing_analysis.coherence_score = analysis_result.get('scores', {}).get('coherence', 0)
|
||||
existing_analysis.analysis_report = analyzer.generate_analysis_summary(analysis_result)
|
||||
existing_analysis.suggestions = analysis_result.get('suggestions', [])
|
||||
existing_analysis.dialogue_ratio = analysis_result.get('dialogue_ratio', 0)
|
||||
existing_analysis.description_ratio = analysis_result.get('description_ratio', 0)
|
||||
else:
|
||||
# 创建新记录
|
||||
logger.info(f" 创建新的分析记录")
|
||||
plot_analysis = PlotAnalysis(
|
||||
chapter_id=chapter_id,
|
||||
project_id=project_id,
|
||||
plot_stage=analysis_result.get('plot_stage', '发展'),
|
||||
conflict_level=analysis_result.get('conflict', {}).get('level', 0),
|
||||
conflict_types=analysis_result.get('conflict', {}).get('types', []),
|
||||
emotional_tone=analysis_result.get('emotional_arc', {}).get('primary_emotion', ''),
|
||||
emotional_intensity=analysis_result.get('emotional_arc', {}).get('intensity', 0) / 10.0,
|
||||
hooks=analysis_result.get('hooks', []),
|
||||
hooks_count=len(analysis_result.get('hooks', [])),
|
||||
foreshadows=analysis_result.get('foreshadows', []),
|
||||
foreshadows_planted=sum(1 for f in analysis_result.get('foreshadows', []) if f.get('type') == 'planted'),
|
||||
foreshadows_resolved=sum(1 for f in analysis_result.get('foreshadows', []) if f.get('type') == 'resolved'),
|
||||
plot_points=analysis_result.get('plot_points', []),
|
||||
plot_points_count=len(analysis_result.get('plot_points', [])),
|
||||
character_states=analysis_result.get('character_states', []),
|
||||
scenes=analysis_result.get('scenes', []),
|
||||
pacing=analysis_result.get('pacing', 'moderate'),
|
||||
overall_quality_score=analysis_result.get('scores', {}).get('overall', 0),
|
||||
pacing_score=analysis_result.get('scores', {}).get('pacing', 0),
|
||||
engagement_score=analysis_result.get('scores', {}).get('engagement', 0),
|
||||
coherence_score=analysis_result.get('scores', {}).get('coherence', 0),
|
||||
analysis_report=analyzer.generate_analysis_summary(analysis_result),
|
||||
suggestions=analysis_result.get('suggestions', []),
|
||||
dialogue_ratio=analysis_result.get('dialogue_ratio', 0),
|
||||
description_ratio=analysis_result.get('description_ratio', 0)
|
||||
)
|
||||
db_session.add(plot_analysis)
|
||||
|
||||
await db_session.commit()
|
||||
|
||||
task.progress = 80
|
||||
await db_session.commit()
|
||||
|
||||
# 5. 提取记忆并保存到向量数据库(传入章节内容用于计算位置)
|
||||
memories = analyzer.extract_memories_from_analysis(
|
||||
analysis=analysis_result,
|
||||
chapter_id=chapter_id,
|
||||
chapter_number=chapter.chapter_number,
|
||||
chapter_content=chapter.content or ""
|
||||
)
|
||||
|
||||
# 先删除该章节的旧记忆(支持重新分析)
|
||||
old_memories_result = await db_session.execute(
|
||||
select(StoryMemory).where(StoryMemory.chapter_id == chapter_id)
|
||||
)
|
||||
old_memories = old_memories_result.scalars().all()
|
||||
for old_mem in old_memories:
|
||||
await db_session.delete(old_mem)
|
||||
logger.info(f" 删除旧记忆: {len(old_memories)}条")
|
||||
|
||||
# 准备批量添加的记忆数据
|
||||
memory_records = []
|
||||
for mem in memories:
|
||||
memory_id = f"{chapter_id}_{mem['type']}_{len(memory_records)}"
|
||||
memory_records.append({
|
||||
'id': memory_id,
|
||||
'content': mem['content'],
|
||||
'type': mem['type'],
|
||||
'metadata': mem['metadata']
|
||||
})
|
||||
|
||||
# 从metadata中提取位置信息
|
||||
text_position = mem['metadata'].get('text_position', -1)
|
||||
text_length = mem['metadata'].get('text_length', 0)
|
||||
|
||||
# 同时保存到关系数据库
|
||||
story_memory = StoryMemory(
|
||||
id=memory_id,
|
||||
project_id=project_id,
|
||||
chapter_id=chapter_id,
|
||||
memory_type=mem['type'],
|
||||
content=mem['content'],
|
||||
title=mem['title'],
|
||||
importance_score=mem['metadata'].get('importance_score', 0.5),
|
||||
tags=mem['metadata'].get('tags', []),
|
||||
is_foreshadow=mem['metadata'].get('is_foreshadow', 0),
|
||||
story_timeline=chapter.chapter_number, # 使用章节序号作为时间线
|
||||
chapter_position=text_position, # 保存文本位置
|
||||
text_length=text_length, # 保存文本长度
|
||||
related_characters=mem['metadata'].get('related_characters', []),
|
||||
related_locations=mem['metadata'].get('related_locations', [])
|
||||
)
|
||||
db_session.add(story_memory)
|
||||
|
||||
# 记录日志便于调试
|
||||
if text_position >= 0:
|
||||
logger.debug(f" 保存记忆 {memory_id}: position={text_position}, length={text_length}")
|
||||
|
||||
await db_session.commit()
|
||||
|
||||
# 批量添加到向量数据库
|
||||
if memory_records:
|
||||
added_count = await memory_service.batch_add_memories(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
memories=memory_records
|
||||
)
|
||||
logger.info(f"✅ 添加{added_count}条记忆到向量库")
|
||||
|
||||
task.progress = 100
|
||||
task.status = 'completed'
|
||||
task.completed_at = datetime.now()
|
||||
await db_session.commit()
|
||||
|
||||
logger.info(f"✅ 章节分析完成: {chapter_id}, 提取{len(memories)}条记忆")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 后台分析异常: {str(e)}", exc_info=True)
|
||||
# 确保任务状态被更新为failed,避免前端一直轮询
|
||||
if db_session:
|
||||
try:
|
||||
# 重新获取任务以确保有最新状态
|
||||
task_result = await db_session.execute(
|
||||
select(AnalysisTask).where(AnalysisTask.id == task_id)
|
||||
)
|
||||
task = task_result.scalar_one_or_none()
|
||||
if task:
|
||||
task.status = 'failed'
|
||||
task.error_message = str(e)[:500]
|
||||
task.completed_at = datetime.now()
|
||||
task.progress = 0 # 重置进度
|
||||
await db_session.commit()
|
||||
logger.info(f"✅ 任务状态已更新为failed: {task_id}")
|
||||
else:
|
||||
logger.error(f"❌ 无法找到任务进行状态更新: {task_id}")
|
||||
except Exception as update_error:
|
||||
logger.error(f"❌ 更新任务状态失败: {str(update_error)}")
|
||||
finally:
|
||||
if db_session:
|
||||
await db_session.close()
|
||||
|
||||
|
||||
@router.post("/{chapter_id}/generate-stream", summary="AI创作章节内容(流式)")
|
||||
async def generate_chapter_content_stream(
|
||||
chapter_id: str,
|
||||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
generate_request: ChapterGenerateRequest = ChapterGenerateRequest(),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
@@ -301,6 +626,9 @@ async def generate_chapter_content_stream(
|
||||
# 在生成器内部创建独立的数据库会话
|
||||
db_session = None
|
||||
db_committed = False
|
||||
# 获取当前用户ID(在生成器外部就需要)
|
||||
current_user_id = getattr(request.state, "user_id", "system")
|
||||
|
||||
try:
|
||||
# 创建新的数据库会话
|
||||
async for db_session in get_db(request):
|
||||
@@ -396,11 +724,42 @@ async def generate_chapter_content_stream(
|
||||
previous_content += recent_content
|
||||
|
||||
logger.info(f"构建前置上下文:{len(early_chapters)}章摘要 + {len(recent_chapters)}章完整内容")
|
||||
|
||||
# 🧠 构建记忆增强上下文
|
||||
logger.info(f"🧠 开始构建记忆增强上下文...")
|
||||
memory_context = await memory_service.build_context_for_generation(
|
||||
user_id=current_user_id,
|
||||
project_id=project.id,
|
||||
current_chapter=current_chapter.chapter_number,
|
||||
chapter_outline=outline.content if outline else current_chapter.summary or "",
|
||||
character_names=[c.name for c in characters] if characters else None
|
||||
)
|
||||
|
||||
# 计算各部分的字符长度
|
||||
context_lengths = {
|
||||
'recent_context': len(memory_context.get('recent_context', '')),
|
||||
'relevant_memories': len(memory_context.get('relevant_memories', '')),
|
||||
'foreshadows': len(memory_context.get('foreshadows', '')),
|
||||
'character_states': len(memory_context.get('character_states', '')),
|
||||
'plot_points': len(memory_context.get('plot_points', ''))
|
||||
}
|
||||
total_memory_length = sum(context_lengths.values())
|
||||
|
||||
logger.info(f"✅ 记忆上下文构建完成: {memory_context['stats']}")
|
||||
logger.info(f"📏 记忆上下文长度统计:")
|
||||
logger.info(f" - 最近章节记忆: {context_lengths['recent_context']} 字符")
|
||||
logger.info(f" - 语义相关记忆: {context_lengths['relevant_memories']} 字符")
|
||||
logger.info(f" - 未完结伏笔: {context_lengths['foreshadows']} 字符")
|
||||
logger.info(f" - 角色状态记忆: {context_lengths['character_states']} 字符")
|
||||
logger.info(f" - 重要情节点: {context_lengths['plot_points']} 字符")
|
||||
logger.info(f" - 记忆总长度: {total_memory_length} 字符")
|
||||
logger.info(f" - 前置章节上下文长度: {len(previous_content)} 字符")
|
||||
logger.info(f" - 总上下文长度(估算): {total_memory_length + len(previous_content) + 2000} 字符")
|
||||
|
||||
# 发送开始事件
|
||||
yield f"data: {json.dumps({'type': 'start', 'message': '开始AI创作...'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 根据是否有前置内容选择不同的提示词,并应用写作风格
|
||||
# 根据是否有前置内容选择不同的提示词,并应用写作风格和记忆增强
|
||||
if previous_content:
|
||||
prompt = prompt_service.get_chapter_generation_with_context_prompt(
|
||||
title=project.title,
|
||||
@@ -418,7 +777,8 @@ async def generate_chapter_content_stream(
|
||||
chapter_title=current_chapter.title,
|
||||
chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲',
|
||||
style_content=style_content,
|
||||
target_word_count=target_word_count
|
||||
target_word_count=target_word_count,
|
||||
memory_context=memory_context
|
||||
)
|
||||
else:
|
||||
prompt = prompt_service.get_chapter_generation_prompt(
|
||||
@@ -436,7 +796,8 @@ async def generate_chapter_content_stream(
|
||||
chapter_title=current_chapter.title,
|
||||
chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲',
|
||||
style_content=style_content,
|
||||
target_word_count=target_word_count
|
||||
target_word_count=target_word_count,
|
||||
memory_context=memory_context
|
||||
)
|
||||
|
||||
logger.info(f"开始AI流式创作章节 {chapter_id}")
|
||||
@@ -474,8 +835,48 @@ async def generate_chapter_content_stream(
|
||||
|
||||
logger.info(f"成功创作章节 {chapter_id},共 {new_word_count} 字")
|
||||
|
||||
# 发送完成事件
|
||||
yield f"data: {json.dumps({'type': 'done', 'message': '创作完成', 'word_count': new_word_count}, ensure_ascii=False)}\n\n"
|
||||
# 创建分析任务并启动后台分析
|
||||
analysis_task = AnalysisTask(
|
||||
chapter_id=chapter_id,
|
||||
user_id=current_user_id,
|
||||
project_id=project.id,
|
||||
status='pending',
|
||||
progress=0
|
||||
)
|
||||
db_session.add(analysis_task)
|
||||
await db_session.commit()
|
||||
# 不需要refresh,只需要获取ID
|
||||
|
||||
task_id = analysis_task.id
|
||||
|
||||
# 启动后台分析任务
|
||||
background_tasks.add_task(
|
||||
analyze_chapter_background,
|
||||
chapter_id=chapter_id,
|
||||
user_id=current_user_id,
|
||||
project_id=project.id,
|
||||
task_id=task_id,
|
||||
ai_service=user_ai_service
|
||||
)
|
||||
|
||||
logger.info(f"📋 已创建分析任务: {task_id}")
|
||||
|
||||
# 发送完成事件(包含分析任务ID)
|
||||
completion_data = {
|
||||
'type': 'done',
|
||||
'message': '创作完成',
|
||||
'word_count': new_word_count,
|
||||
'analysis_task_id': task_id
|
||||
}
|
||||
yield f"data: {json.dumps(completion_data, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 发送分析排队事件
|
||||
analysis_queued_data = {
|
||||
'type': 'analysis_queued',
|
||||
'task_id': task_id,
|
||||
'message': '章节分析已加入队列'
|
||||
}
|
||||
yield f"data: {json.dumps(analysis_queued_data, ensure_ascii=False)}\n\n"
|
||||
|
||||
break # 退出async for db_session循环
|
||||
|
||||
@@ -527,3 +928,308 @@ async def generate_chapter_content_stream(
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{chapter_id}/analysis/status", summary="查询章节分析任务状态")
|
||||
async def get_analysis_task_status(
|
||||
chapter_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
查询指定章节的最新分析任务状态
|
||||
|
||||
返回:
|
||||
- task_id: 任务ID
|
||||
- status: pending/running/completed/failed
|
||||
- progress: 0-100
|
||||
- error_message: 错误信息(如果失败)
|
||||
- created_at: 创建时间
|
||||
- completed_at: 完成时间
|
||||
"""
|
||||
# 获取该章节最新的分析任务
|
||||
result = await db.execute(
|
||||
select(AnalysisTask)
|
||||
.where(AnalysisTask.chapter_id == chapter_id)
|
||||
.order_by(AnalysisTask.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
task = result.scalar_one_or_none()
|
||||
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="未找到分析任务")
|
||||
|
||||
return {
|
||||
"task_id": task.id,
|
||||
"chapter_id": task.chapter_id,
|
||||
"status": task.status,
|
||||
"progress": task.progress,
|
||||
"error_message": task.error_message,
|
||||
"created_at": task.created_at.isoformat() if task.created_at else None,
|
||||
"started_at": task.started_at.isoformat() if task.started_at else None,
|
||||
"completed_at": task.completed_at.isoformat() if task.completed_at else None
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{chapter_id}/analysis", summary="获取章节分析结果")
|
||||
async def get_chapter_analysis(
|
||||
chapter_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取章节的完整分析结果
|
||||
|
||||
返回:
|
||||
- analysis_data: 完整的分析数据(JSON)
|
||||
- summary: 分析摘要文本
|
||||
- memories: 提取的记忆列表
|
||||
- created_at: 分析时间
|
||||
"""
|
||||
# 获取分析结果
|
||||
analysis_result = await db.execute(
|
||||
select(PlotAnalysis)
|
||||
.where(PlotAnalysis.chapter_id == chapter_id)
|
||||
.order_by(PlotAnalysis.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
analysis = analysis_result.scalar_one_or_none()
|
||||
|
||||
if not analysis:
|
||||
raise HTTPException(status_code=404, detail="该章节暂无分析结果")
|
||||
|
||||
# 获取相关记忆
|
||||
memories_result = await db.execute(
|
||||
select(StoryMemory)
|
||||
.where(StoryMemory.chapter_id == chapter_id)
|
||||
.order_by(StoryMemory.importance_score.desc())
|
||||
)
|
||||
memories = memories_result.scalars().all()
|
||||
|
||||
return {
|
||||
"chapter_id": chapter_id,
|
||||
"analysis": analysis.to_dict(), # 使用to_dict()方法
|
||||
"memories": [
|
||||
{
|
||||
"id": mem.id,
|
||||
"type": mem.memory_type,
|
||||
"title": mem.title,
|
||||
"content": mem.content,
|
||||
"importance": mem.importance_score,
|
||||
"tags": mem.tags,
|
||||
"is_foreshadow": mem.is_foreshadow,
|
||||
"position": mem.chapter_position,
|
||||
"related_characters": mem.related_characters
|
||||
}
|
||||
for mem in memories
|
||||
],
|
||||
"created_at": analysis.created_at.isoformat() if analysis.created_at else None
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{chapter_id}/annotations", summary="获取章节标注数据")
|
||||
async def get_chapter_annotations(
|
||||
chapter_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取章节的标注数据(用于前端展示标注)
|
||||
|
||||
返回格式化的标注列表,包含精确位置信息
|
||||
适用于章节内容的可视化标注展示
|
||||
"""
|
||||
# 获取章节
|
||||
chapter_result = await db.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
chapter = chapter_result.scalar_one_or_none()
|
||||
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
# 获取分析结果
|
||||
analysis_result = await db.execute(
|
||||
select(PlotAnalysis)
|
||||
.where(PlotAnalysis.chapter_id == chapter_id)
|
||||
.order_by(PlotAnalysis.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
analysis = analysis_result.scalar_one_or_none()
|
||||
|
||||
# 获取记忆
|
||||
memories_result = await db.execute(
|
||||
select(StoryMemory)
|
||||
.where(StoryMemory.chapter_id == chapter_id)
|
||||
.order_by(StoryMemory.importance_score.desc())
|
||||
)
|
||||
memories = memories_result.scalars().all()
|
||||
|
||||
# 构建标注数据
|
||||
annotations = []
|
||||
|
||||
for mem in memories:
|
||||
# 优先从数据库读取位置信息
|
||||
position = mem.chapter_position if mem.chapter_position is not None else -1
|
||||
length = mem.text_length if hasattr(mem, 'text_length') and mem.text_length is not None else 0
|
||||
metadata_extra = {}
|
||||
|
||||
# 如果数据库中没有位置信息,尝试从分析数据中重新计算
|
||||
if position == -1 and analysis and chapter.content:
|
||||
# 根据记忆类型从分析数据中查找对应项
|
||||
if mem.memory_type == 'hook' and analysis.hooks:
|
||||
for hook in analysis.hooks:
|
||||
# 通过标题或内容匹配
|
||||
if mem.title and hook.get('type') in mem.title:
|
||||
keyword = hook.get('keyword', '')
|
||||
if keyword:
|
||||
pos = chapter.content.find(keyword)
|
||||
if pos != -1:
|
||||
position = pos
|
||||
length = len(keyword)
|
||||
metadata_extra["strength"] = hook.get('strength', 5)
|
||||
metadata_extra["position_desc"] = hook.get('position', '')
|
||||
break
|
||||
|
||||
elif mem.memory_type == 'foreshadow' and analysis.foreshadows:
|
||||
for foreshadow in analysis.foreshadows:
|
||||
if foreshadow.get('content') in mem.content:
|
||||
keyword = foreshadow.get('keyword', '')
|
||||
if keyword:
|
||||
pos = chapter.content.find(keyword)
|
||||
if pos != -1:
|
||||
position = pos
|
||||
length = len(keyword)
|
||||
metadata_extra["foreshadow_type"] = foreshadow.get('type', 'planted')
|
||||
metadata_extra["strength"] = foreshadow.get('strength', 5)
|
||||
break
|
||||
|
||||
elif mem.memory_type == 'plot_point' and analysis.plot_points:
|
||||
for plot_point in analysis.plot_points:
|
||||
if plot_point.get('content') in mem.content:
|
||||
keyword = plot_point.get('keyword', '')
|
||||
if keyword:
|
||||
pos = chapter.content.find(keyword)
|
||||
if pos != -1:
|
||||
position = pos
|
||||
length = len(keyword)
|
||||
break
|
||||
else:
|
||||
# 如果数据库有位置,也从分析数据中提取额外的元数据
|
||||
if analysis:
|
||||
if mem.memory_type == 'hook' and analysis.hooks:
|
||||
for hook in analysis.hooks:
|
||||
if mem.title and hook.get('type') in mem.title:
|
||||
metadata_extra["strength"] = hook.get('strength', 5)
|
||||
metadata_extra["position_desc"] = hook.get('position', '')
|
||||
break
|
||||
|
||||
elif mem.memory_type == 'foreshadow' and analysis.foreshadows:
|
||||
for foreshadow in analysis.foreshadows:
|
||||
if foreshadow.get('content') in mem.content:
|
||||
metadata_extra["foreshadow_type"] = foreshadow.get('type', 'planted')
|
||||
metadata_extra["strength"] = foreshadow.get('strength', 5)
|
||||
break
|
||||
|
||||
annotation = {
|
||||
"id": mem.id,
|
||||
"type": mem.memory_type,
|
||||
"title": mem.title,
|
||||
"content": mem.content,
|
||||
"importance": mem.importance_score or 0.5,
|
||||
"position": position,
|
||||
"length": length,
|
||||
"tags": mem.tags or [],
|
||||
"metadata": {
|
||||
"is_foreshadow": mem.is_foreshadow,
|
||||
"related_characters": mem.related_characters or [],
|
||||
"related_locations": mem.related_locations or [],
|
||||
**metadata_extra
|
||||
}
|
||||
}
|
||||
|
||||
annotations.append(annotation)
|
||||
|
||||
return {
|
||||
"chapter_id": chapter_id,
|
||||
"chapter_number": chapter.chapter_number,
|
||||
"title": chapter.title,
|
||||
"word_count": chapter.word_count or 0,
|
||||
"annotations": annotations,
|
||||
"has_analysis": analysis is not None,
|
||||
"summary": {
|
||||
"total_annotations": len(annotations),
|
||||
"hooks": len([a for a in annotations if a["type"] == "hook"]),
|
||||
"foreshadows": len([a for a in annotations if a["type"] == "foreshadow"]),
|
||||
"plot_points": len([a for a in annotations if a["type"] == "plot_point"]),
|
||||
"character_events": len([a for a in annotations if a["type"] == "character_event"])
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{chapter_id}/analyze", summary="手动触发章节分析")
|
||||
async def trigger_chapter_analysis(
|
||||
chapter_id: str,
|
||||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
"""
|
||||
手动触发章节分析(用于重新分析或分析旧章节)
|
||||
"""
|
||||
# 从请求中获取用户ID
|
||||
user_id = getattr(request.state, "user_id", None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
# 验证章节存在
|
||||
chapter_result = await db.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
chapter = chapter_result.scalar_one_or_none()
|
||||
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
if not chapter.content or chapter.content.strip() == "":
|
||||
raise HTTPException(status_code=400, detail="章节内容为空,无法分析")
|
||||
|
||||
# 获取项目信息
|
||||
project_result = await db.execute(
|
||||
select(Project).where(Project.id == chapter.project_id)
|
||||
)
|
||||
project = project_result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
# 创建分析任务
|
||||
analysis_task = AnalysisTask(
|
||||
chapter_id=chapter_id,
|
||||
user_id=user_id,
|
||||
project_id=project.id,
|
||||
status='pending',
|
||||
progress=0
|
||||
)
|
||||
db.add(analysis_task)
|
||||
await db.commit()
|
||||
# 注意:不需要refresh,因为我们只需要id,而id在commit后已经生成
|
||||
|
||||
task_id = analysis_task.id
|
||||
|
||||
# 启动后台分析任务
|
||||
background_tasks.add_task(
|
||||
analyze_chapter_background,
|
||||
chapter_id=chapter_id,
|
||||
user_id=user_id,
|
||||
project_id=project.id,
|
||||
task_id=task_id,
|
||||
ai_service=user_ai_service
|
||||
)
|
||||
|
||||
logger.info(f"📋 手动触发分析任务: {task_id}")
|
||||
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"chapter_id": chapter_id,
|
||||
"status": "pending",
|
||||
"message": "分析任务已创建并加入队列"
|
||||
}
|
||||
|
||||
@@ -0,0 +1,383 @@
|
||||
"""记忆管理API - 提供记忆的查询、分析等接口"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, desc
|
||||
from typing import List, Optional
|
||||
from app.database import get_db
|
||||
from app.models.memory import StoryMemory, PlotAnalysis
|
||||
from app.models.chapter import Chapter
|
||||
from app.services.memory_service import memory_service
|
||||
from app.services.plot_analyzer import get_plot_analyzer
|
||||
from app.services.ai_service import create_user_ai_service
|
||||
from app.models.settings import Settings
|
||||
from app.logger import get_logger
|
||||
import uuid
|
||||
|
||||
logger = get_logger(__name__)
|
||||
router = APIRouter(prefix="/api/memories", tags=["memories"])
|
||||
|
||||
|
||||
@router.post("/projects/{project_id}/analyze-chapter/{chapter_id}")
|
||||
async def analyze_chapter(
|
||||
project_id: str,
|
||||
chapter_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
分析章节并生成记忆
|
||||
|
||||
对指定章节进行剧情分析,提取钩子、伏笔、情节点等,并存入记忆系统
|
||||
"""
|
||||
try:
|
||||
user_id = request.state.user_id
|
||||
|
||||
# 获取章节内容
|
||||
result = await db.execute(
|
||||
select(Chapter).where(
|
||||
and_(
|
||||
Chapter.id == chapter_id,
|
||||
Chapter.project_id == project_id
|
||||
)
|
||||
)
|
||||
)
|
||||
chapter = result.scalar_one_or_none()
|
||||
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
if not chapter.content:
|
||||
raise HTTPException(status_code=400, detail="章节内容为空,无法分析")
|
||||
|
||||
# 获取用户AI设置
|
||||
settings_result = await db.execute(select(Settings))
|
||||
settings = settings_result.scalar_one_or_none()
|
||||
|
||||
if not settings:
|
||||
raise HTTPException(status_code=400, detail="请先配置AI设置")
|
||||
|
||||
# 创建AI服务
|
||||
ai_service = create_user_ai_service(
|
||||
api_provider=settings.api_provider,
|
||||
api_key=settings.api_key,
|
||||
api_base_url=settings.api_base_url,
|
||||
model_name=settings.model_name,
|
||||
temperature=settings.temperature,
|
||||
max_tokens=settings.max_tokens
|
||||
)
|
||||
|
||||
# 执行剧情分析
|
||||
analyzer = get_plot_analyzer(ai_service)
|
||||
analysis_result = await analyzer.analyze_chapter(
|
||||
chapter_number=chapter.chapter_number,
|
||||
title=chapter.title,
|
||||
content=chapter.content,
|
||||
word_count=chapter.word_count or len(chapter.content)
|
||||
)
|
||||
|
||||
if not analysis_result:
|
||||
raise HTTPException(status_code=500, detail="剧情分析失败")
|
||||
|
||||
# 保存分析结果到数据库
|
||||
plot_analysis = PlotAnalysis(
|
||||
id=str(uuid.uuid4()),
|
||||
project_id=project_id,
|
||||
chapter_id=chapter_id,
|
||||
plot_stage=analysis_result.get('plot_stage'),
|
||||
conflict_level=analysis_result.get('conflict', {}).get('level'),
|
||||
conflict_types=analysis_result.get('conflict', {}).get('types'),
|
||||
emotional_tone=analysis_result.get('emotional_arc', {}).get('primary_emotion'),
|
||||
emotional_intensity=analysis_result.get('emotional_arc', {}).get('intensity', 0) / 10,
|
||||
emotional_curve=analysis_result.get('emotional_arc'),
|
||||
hooks=analysis_result.get('hooks'),
|
||||
hooks_count=len(analysis_result.get('hooks', [])),
|
||||
hooks_avg_strength=sum(h.get('strength', 0) for h in analysis_result.get('hooks', [])) / max(len(analysis_result.get('hooks', [])), 1),
|
||||
foreshadows=analysis_result.get('foreshadows'),
|
||||
foreshadows_planted=sum(1 for f in analysis_result.get('foreshadows', []) if f.get('type') == 'planted'),
|
||||
foreshadows_resolved=sum(1 for f in analysis_result.get('foreshadows', []) if f.get('type') == 'resolved'),
|
||||
plot_points=analysis_result.get('plot_points'),
|
||||
plot_points_count=len(analysis_result.get('plot_points', [])),
|
||||
character_states=analysis_result.get('character_states'),
|
||||
scenes=analysis_result.get('scenes'),
|
||||
pacing=analysis_result.get('pacing'),
|
||||
dialogue_ratio=analysis_result.get('dialogue_ratio'),
|
||||
description_ratio=analysis_result.get('description_ratio'),
|
||||
overall_quality_score=analysis_result.get('scores', {}).get('overall'),
|
||||
pacing_score=analysis_result.get('scores', {}).get('pacing'),
|
||||
engagement_score=analysis_result.get('scores', {}).get('engagement'),
|
||||
coherence_score=analysis_result.get('scores', {}).get('coherence'),
|
||||
analysis_report=analyzer.generate_analysis_summary(analysis_result),
|
||||
suggestions=analysis_result.get('suggestions'),
|
||||
word_count=chapter.word_count
|
||||
)
|
||||
|
||||
# 检查是否已存在分析记录
|
||||
existing = await db.execute(
|
||||
select(PlotAnalysis).where(PlotAnalysis.chapter_id == chapter_id)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
# 删除旧记录
|
||||
await db.execute(
|
||||
select(PlotAnalysis).where(PlotAnalysis.chapter_id == chapter_id)
|
||||
)
|
||||
await db.delete(existing.scalar_one())
|
||||
|
||||
db.add(plot_analysis)
|
||||
await db.commit()
|
||||
|
||||
# 从分析结果中提取记忆片段
|
||||
memories_data = analyzer.extract_memories_from_analysis(
|
||||
analysis_result,
|
||||
chapter_id,
|
||||
chapter.chapter_number
|
||||
)
|
||||
|
||||
# 保存记忆到数据库和向量库
|
||||
saved_count = 0
|
||||
for mem_data in memories_data:
|
||||
memory_id = str(uuid.uuid4())
|
||||
|
||||
# 保存到关系数据库
|
||||
memory = StoryMemory(
|
||||
id=memory_id,
|
||||
project_id=project_id,
|
||||
chapter_id=chapter_id,
|
||||
memory_type=mem_data['type'],
|
||||
title=mem_data.get('title', ''),
|
||||
content=mem_data['content'],
|
||||
story_timeline=chapter.chapter_number,
|
||||
vector_id=memory_id,
|
||||
**mem_data['metadata']
|
||||
)
|
||||
db.add(memory)
|
||||
|
||||
# 保存到向量库
|
||||
await memory_service.add_memory(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
memory_id=memory_id,
|
||||
content=mem_data['content'],
|
||||
memory_type=mem_data['type'],
|
||||
metadata=mem_data['metadata']
|
||||
)
|
||||
saved_count += 1
|
||||
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"✅ 章节分析完成: 保存{saved_count}条记忆")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"分析完成,提取了{saved_count}条记忆",
|
||||
"analysis": plot_analysis.to_dict(),
|
||||
"memories_count": saved_count
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 章节分析失败: {str(e)}")
|
||||
await db.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"分析失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/projects/{project_id}/memories")
|
||||
async def get_project_memories(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
memory_type: Optional[str] = None,
|
||||
chapter_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取项目的记忆列表"""
|
||||
try:
|
||||
user_id = request.state.user_id
|
||||
|
||||
# 构建查询
|
||||
query = select(StoryMemory).where(StoryMemory.project_id == project_id)
|
||||
|
||||
if memory_type:
|
||||
query = query.where(StoryMemory.memory_type == memory_type)
|
||||
if chapter_id:
|
||||
query = query.where(StoryMemory.chapter_id == chapter_id)
|
||||
|
||||
query = query.order_by(desc(StoryMemory.importance_score), desc(StoryMemory.created_at)).limit(limit)
|
||||
|
||||
result = await db.execute(query)
|
||||
memories = result.scalars().all()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"memories": [mem.to_dict() for mem in memories],
|
||||
"total": len(memories)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 获取记忆失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/projects/{project_id}/analysis/{chapter_id}")
|
||||
async def get_chapter_analysis(
|
||||
project_id: str,
|
||||
chapter_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取章节的剧情分析"""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(PlotAnalysis).where(
|
||||
and_(
|
||||
PlotAnalysis.project_id == project_id,
|
||||
PlotAnalysis.chapter_id == chapter_id
|
||||
)
|
||||
)
|
||||
)
|
||||
analysis = result.scalar_one_or_none()
|
||||
|
||||
if not analysis:
|
||||
raise HTTPException(status_code=404, detail="该章节还未进行分析")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"analysis": analysis.to_dict()
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 获取分析失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/projects/{project_id}/search")
|
||||
async def search_memories(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
query: str,
|
||||
memory_types: Optional[List[str]] = None,
|
||||
limit: int = 10,
|
||||
min_importance: float = 0.0
|
||||
):
|
||||
"""语义搜索项目记忆"""
|
||||
try:
|
||||
user_id = request.state.user_id
|
||||
|
||||
memories = await memory_service.search_memories(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
query=query,
|
||||
memory_types=memory_types,
|
||||
limit=limit,
|
||||
min_importance=min_importance
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"query": query,
|
||||
"memories": memories,
|
||||
"total": len(memories)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 搜索记忆失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/projects/{project_id}/foreshadows")
|
||||
async def get_unresolved_foreshadows(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
current_chapter: int,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取未完结的伏笔"""
|
||||
try:
|
||||
user_id = request.state.user_id
|
||||
|
||||
# 从向量库搜索
|
||||
foreshadows = await memory_service.find_unresolved_foreshadows(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
current_chapter=current_chapter
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"foreshadows": foreshadows,
|
||||
"total": len(foreshadows)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 获取伏笔失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/projects/{project_id}/stats")
|
||||
async def get_memory_stats(
|
||||
project_id: str,
|
||||
request: Request
|
||||
):
|
||||
"""获取记忆统计信息"""
|
||||
try:
|
||||
user_id = request.state.user_id
|
||||
|
||||
stats = await memory_service.get_memory_stats(
|
||||
user_id=user_id,
|
||||
project_id=project_id
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"stats": stats
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 获取统计失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/projects/{project_id}/chapters/{chapter_id}/memories")
|
||||
async def delete_chapter_memories(
|
||||
project_id: str,
|
||||
chapter_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""删除章节的所有记忆"""
|
||||
try:
|
||||
user_id = request.state.user_id
|
||||
|
||||
# 从数据库删除
|
||||
result = await db.execute(
|
||||
select(StoryMemory).where(
|
||||
and_(
|
||||
StoryMemory.project_id == project_id,
|
||||
StoryMemory.chapter_id == chapter_id
|
||||
)
|
||||
)
|
||||
)
|
||||
memories = result.scalars().all()
|
||||
|
||||
for memory in memories:
|
||||
await db.delete(memory)
|
||||
|
||||
# 从向量库删除
|
||||
await memory_service.delete_chapter_memories(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
chapter_id=chapter_id
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"已删除{len(memories)}条记忆"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 删除记忆失败: {str(e)}")
|
||||
await db.rollback()
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
+62
-12
@@ -1,5 +1,5 @@
|
||||
"""大纲管理API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, delete
|
||||
from typing import List, AsyncGenerator, Dict, Any
|
||||
@@ -21,6 +21,7 @@ from app.schemas.outline import (
|
||||
)
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.prompt_service import prompt_service
|
||||
from app.services.memory_service import memory_service
|
||||
from app.logger import get_logger
|
||||
from app.api.settings import get_user_ai_service
|
||||
from app.utils.sse_response import SSEResponse, create_sse_response
|
||||
@@ -328,6 +329,7 @@ async def reorder_outlines(
|
||||
@router.post("/generate", response_model=OutlineListResponse, summary="AI生成/续写大纲")
|
||||
async def generate_outline(
|
||||
request: OutlineGenerateRequest,
|
||||
http_request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
@@ -377,8 +379,10 @@ async def generate_outline(
|
||||
detail="续写模式需要已有大纲,当前项目没有大纲"
|
||||
)
|
||||
|
||||
# 获取用户ID用于记忆检索
|
||||
user_id = getattr(http_request.state, "user_id", "system")
|
||||
return await _continue_outline(
|
||||
request, project, existing_outlines, db, user_ai_service
|
||||
request, project, existing_outlines, db, user_ai_service, user_id
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -478,9 +482,10 @@ async def _continue_outline(
|
||||
project: Project,
|
||||
existing_outlines: List[Outline],
|
||||
db: AsyncSession,
|
||||
user_ai_service: AIService
|
||||
user_ai_service: AIService,
|
||||
user_id: str = "system"
|
||||
) -> OutlineListResponse:
|
||||
"""续写大纲 - 分批生成,每批5章"""
|
||||
"""续写大纲 - 分批生成,每批5章(记忆增强版)"""
|
||||
logger.info(f"续写大纲 - 项目: {project.id}, 已有: {len(existing_outlines)} 章")
|
||||
|
||||
# 分析已有大纲
|
||||
@@ -545,7 +550,25 @@ async def _continue_outline(
|
||||
for o in latest_outlines
|
||||
])
|
||||
|
||||
# 使用标准续写提示词模板
|
||||
# 🧠 构建记忆增强上下文(仅续写模式需要)
|
||||
memory_context = None
|
||||
try:
|
||||
logger.info(f"🧠 为第{batch_num + 1}批构建记忆上下文...")
|
||||
# 使用最近一章的大纲作为查询
|
||||
query_outline = recent_outlines[-1].content if recent_outlines else ""
|
||||
memory_context = await memory_service.build_context_for_generation(
|
||||
user_id=user_id,
|
||||
project_id=project.id,
|
||||
current_chapter=current_start_chapter,
|
||||
chapter_outline=query_outline,
|
||||
character_names=[c.name for c in characters] if characters else None
|
||||
)
|
||||
logger.info(f"✅ 记忆上下文构建完成: {memory_context['stats']}")
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 记忆上下文构建失败,继续不使用记忆: {str(e)}")
|
||||
memory_context = None
|
||||
|
||||
# 使用标准续写提示词模板(支持记忆增强)
|
||||
prompt = prompt_service.get_outline_continue_prompt(
|
||||
title=project.title,
|
||||
theme=request.theme or project.theme or "未设定",
|
||||
@@ -563,7 +586,8 @@ async def _continue_outline(
|
||||
plot_stage_instruction=stage_instruction,
|
||||
start_chapter=current_start_chapter,
|
||||
story_direction=request.story_direction or "自然延续",
|
||||
requirements=request.requirements or ""
|
||||
requirements=request.requirements or "",
|
||||
memory_context=memory_context
|
||||
)
|
||||
|
||||
# 调用AI生成当前批次
|
||||
@@ -834,9 +858,10 @@ async def new_outline_generator(
|
||||
async def continue_outline_generator(
|
||||
data: Dict[str, Any],
|
||||
db: AsyncSession,
|
||||
user_ai_service: AIService
|
||||
user_ai_service: AIService,
|
||||
user_id: str = "system"
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""大纲续写SSE生成器 - 分批生成,推送进度"""
|
||||
"""大纲续写SSE生成器 - 分批生成,推送进度(记忆增强版)"""
|
||||
db_committed = False
|
||||
try:
|
||||
yield await SSEResponse.send_progress("开始续写大纲...", 5)
|
||||
@@ -940,12 +965,32 @@ async def continue_outline_generator(
|
||||
for o in latest_outlines
|
||||
])
|
||||
|
||||
# 🧠 构建记忆增强上下文
|
||||
memory_context = None
|
||||
try:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"🧠 构建记忆上下文...",
|
||||
batch_progress + 3
|
||||
)
|
||||
query_outline = recent_outlines[-1].content if recent_outlines else ""
|
||||
memory_context = await memory_service.build_context_for_generation(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
current_chapter=current_start_chapter,
|
||||
chapter_outline=query_outline,
|
||||
character_names=[c.name for c in characters] if characters else None
|
||||
)
|
||||
logger.info(f"✅ 记忆上下文: {memory_context['stats']}")
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 记忆上下文构建失败: {str(e)}")
|
||||
memory_context = None
|
||||
|
||||
yield await SSEResponse.send_progress(
|
||||
f"🤖 调用AI生成第{str(batch_num + 1)}批...",
|
||||
f" 调用AI生成第{str(batch_num + 1)}批...",
|
||||
batch_progress + 5
|
||||
)
|
||||
|
||||
# 使用标准续写提示词模板
|
||||
# 使用标准续写提示词模板(支持记忆增强)
|
||||
prompt = prompt_service.get_outline_continue_prompt(
|
||||
title=project.title,
|
||||
theme=data.get("theme") or project.theme or "未设定",
|
||||
@@ -963,7 +1008,8 @@ async def continue_outline_generator(
|
||||
plot_stage_instruction=stage_instruction,
|
||||
start_chapter=current_start_chapter,
|
||||
story_direction=data.get("story_direction", "自然延续"),
|
||||
requirements=data.get("requirements", "")
|
||||
requirements=data.get("requirements", ""),
|
||||
memory_context=memory_context
|
||||
)
|
||||
|
||||
# 调用AI生成当前批次
|
||||
@@ -1062,6 +1108,7 @@ async def continue_outline_generator(
|
||||
@router.post("/generate-stream", summary="AI生成/续写大纲(SSE流式)")
|
||||
async def generate_outline_stream(
|
||||
data: Dict[str, Any],
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
@@ -1111,6 +1158,9 @@ async def generate_outline_stream(
|
||||
mode = "continue" if existing_outlines else "new"
|
||||
logger.info(f"自动判断模式:{'续写' if existing_outlines else '新建'}")
|
||||
|
||||
# 获取用户ID
|
||||
user_id = getattr(request.state, "user_id", "system")
|
||||
|
||||
# 根据模式选择生成器
|
||||
if mode == "new":
|
||||
return create_sse_response(new_outline_generator(data, db, user_ai_service))
|
||||
@@ -1120,7 +1170,7 @@ async def generate_outline_stream(
|
||||
status_code=400,
|
||||
detail="续写模式需要已有大纲,当前项目没有大纲"
|
||||
)
|
||||
return create_sse_response(continue_outline_generator(data, db, user_ai_service))
|
||||
return create_sse_response(continue_outline_generator(data, db, user_ai_service, user_id))
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
|
||||
+174
-2
@@ -1,9 +1,11 @@
|
||||
"""项目管理API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
|
||||
from fastapi.responses import Response
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, delete
|
||||
from typing import List
|
||||
import json
|
||||
from urllib.parse import quote
|
||||
from app.database import get_db
|
||||
from app.models.project import Project
|
||||
from app.models.character import Character
|
||||
@@ -17,6 +19,12 @@ from app.schemas.project import (
|
||||
ProjectResponse,
|
||||
ProjectListResponse
|
||||
)
|
||||
from app.schemas.import_export import (
|
||||
ExportOptions,
|
||||
ImportValidationResult,
|
||||
ImportResult
|
||||
)
|
||||
from app.services.import_export_service import ImportExportService
|
||||
from app.logger import get_logger
|
||||
from app.utils.data_consistency import (
|
||||
run_full_data_consistency_check,
|
||||
@@ -412,4 +420,168 @@ async def fix_project_member_counts(
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"修复成员计数失败: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"修复失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"修复失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/{project_id}/export-data", summary="导出项目数据为JSON")
|
||||
async def export_project_data(
|
||||
project_id: str,
|
||||
options: ExportOptions,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
导出项目完整数据为JSON格式
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
options: 导出选项
|
||||
|
||||
Returns:
|
||||
JSON文件下载
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始导出项目数据: {project_id}")
|
||||
|
||||
# 检查项目是否存在
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
# 导出数据
|
||||
export_data = await ImportExportService.export_project(
|
||||
project_id=project_id,
|
||||
db=db,
|
||||
include_generation_history=options.include_generation_history,
|
||||
include_writing_styles=options.include_writing_styles
|
||||
)
|
||||
|
||||
# 转换为JSON
|
||||
json_content = export_data.model_dump_json(indent=2, exclude_none=True, by_alias=True)
|
||||
|
||||
# 生成文件名
|
||||
safe_title = "".join(c for c in project.title if c.isalnum() or c in (' ', '-', '_'))
|
||||
from datetime import datetime
|
||||
date_str = datetime.now().strftime("%Y%m%d")
|
||||
filename = f"project_{safe_title}_{date_str}.json"
|
||||
encoded_filename = quote(filename)
|
||||
|
||||
logger.info(f"项目数据导出成功: {filename}")
|
||||
|
||||
return Response(
|
||||
content=json_content.encode('utf-8'),
|
||||
media_type="application/json; charset=utf-8",
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
|
||||
"Content-Type": "application/json; charset=utf-8"
|
||||
}
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"导出项目数据失败: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"导出失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/validate-import", response_model=ImportValidationResult, summary="验证导入文件")
|
||||
async def validate_import_file(
|
||||
file: UploadFile = File(...)
|
||||
):
|
||||
"""
|
||||
验证导入文件的格式和内容
|
||||
|
||||
Args:
|
||||
file: 上传的JSON文件
|
||||
|
||||
Returns:
|
||||
验证结果
|
||||
"""
|
||||
try:
|
||||
logger.info(f"验证导入文件: {file.filename}")
|
||||
|
||||
# 检查文件类型
|
||||
if not file.filename.endswith('.json'):
|
||||
raise HTTPException(status_code=400, detail="只支持JSON格式文件")
|
||||
|
||||
# 读取文件内容
|
||||
content = await file.read()
|
||||
|
||||
# 检查文件大小(50MB限制)
|
||||
max_size = 50 * 1024 * 1024 # 50MB
|
||||
if len(content) > max_size:
|
||||
raise HTTPException(status_code=413, detail="文件大小超过50MB限制")
|
||||
|
||||
# 解析JSON
|
||||
try:
|
||||
data = json.loads(content.decode('utf-8'))
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"无效的JSON格式: {str(e)}")
|
||||
|
||||
# 验证数据
|
||||
validation_result = ImportExportService.validate_import_data(data)
|
||||
|
||||
logger.info(f"文件验证完成: valid={validation_result.valid}")
|
||||
return validation_result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"验证导入文件失败: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"验证失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/import", response_model=ImportResult, summary="导入项目")
|
||||
async def import_project(
|
||||
file: UploadFile = File(...),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
导入项目数据(创建新项目)
|
||||
|
||||
Args:
|
||||
file: 上传的JSON文件
|
||||
|
||||
Returns:
|
||||
导入结果
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始导入项目: {file.filename}")
|
||||
|
||||
# 检查文件类型
|
||||
if not file.filename.endswith('.json'):
|
||||
raise HTTPException(status_code=400, detail="只支持JSON格式文件")
|
||||
|
||||
# 读取文件内容
|
||||
content = await file.read()
|
||||
|
||||
# 检查文件大小
|
||||
max_size = 50 * 1024 * 1024 # 50MB
|
||||
if len(content) > max_size:
|
||||
raise HTTPException(status_code=413, detail="文件大小超过50MB限制")
|
||||
|
||||
# 解析JSON
|
||||
try:
|
||||
data = json.loads(content.decode('utf-8'))
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"无效的JSON格式: {str(e)}")
|
||||
|
||||
# 导入数据
|
||||
import_result = await ImportExportService.import_project(data, db)
|
||||
|
||||
if import_result.success:
|
||||
logger.info(f"项目导入成功: {import_result.project_id}")
|
||||
else:
|
||||
logger.warning(f"项目导入失败: {import_result.message}")
|
||||
|
||||
return import_result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"导入项目失败: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"导入失败: {str(e)}")
|
||||
@@ -15,6 +15,15 @@ logger = get_logger(__name__)
|
||||
# 创建基类
|
||||
Base = declarative_base()
|
||||
|
||||
# 导入所有模型,确保 Base.metadata 能够发现它们
|
||||
# 这必须在 Base 创建之后、init_db 之前导入
|
||||
from app.models import (
|
||||
Project, Outline, Character, Chapter, GenerationHistory,
|
||||
Settings, WritingStyle, ProjectDefaultStyle,
|
||||
RelationshipType, CharacterRelationship, Organization, OrganizationMember,
|
||||
StoryMemory, PlotAnalysis, AnalysisTask
|
||||
)
|
||||
|
||||
# 引擎缓存:每个用户一个引擎
|
||||
_engine_cache: Dict[str, Any] = {}
|
||||
|
||||
|
||||
+2
-1
@@ -114,7 +114,7 @@ async def db_session_stats():
|
||||
from app.api import (
|
||||
projects, outlines, characters, chapters,
|
||||
wizard_stream, relationships, organizations,
|
||||
auth, users, settings, writing_styles
|
||||
auth, users, settings, writing_styles, memories
|
||||
)
|
||||
|
||||
app.include_router(auth.router, prefix="/api")
|
||||
@@ -129,6 +129,7 @@ app.include_router(chapters.router, prefix="/api")
|
||||
app.include_router(relationships.router, prefix="/api")
|
||||
app.include_router(organizations.router, prefix="/api")
|
||||
app.include_router(writing_styles.router, prefix="/api")
|
||||
app.include_router(memories.router) # 记忆管理API (已包含/api前缀)
|
||||
|
||||
static_dir = Path(__file__).parent.parent / "static"
|
||||
if static_dir.exists():
|
||||
|
||||
@@ -13,6 +13,8 @@ from app.models.relationship import (
|
||||
Organization,
|
||||
OrganizationMember
|
||||
)
|
||||
from app.models.memory import StoryMemory, PlotAnalysis
|
||||
from app.models.analysis_task import AnalysisTask
|
||||
|
||||
__all__ = [
|
||||
"Project",
|
||||
@@ -27,4 +29,7 @@ __all__ = [
|
||||
"CharacterRelationship",
|
||||
"Organization",
|
||||
"OrganizationMember",
|
||||
"StoryMemory",
|
||||
"PlotAnalysis",
|
||||
"AnalysisTask",
|
||||
]
|
||||
@@ -0,0 +1,38 @@
|
||||
"""分析任务模型 - 追踪异步章节分析任务状态"""
|
||||
from sqlalchemy import Column, String, Integer, Text, DateTime, ForeignKey, Index
|
||||
from sqlalchemy.sql import func
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class AnalysisTask(Base):
|
||||
"""
|
||||
分析任务表 - 追踪异步分析任务的执行状态
|
||||
|
||||
状态流转: pending -> running -> completed/failed
|
||||
"""
|
||||
__tablename__ = "analysis_tasks"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="任务ID")
|
||||
chapter_id = Column(String(36), ForeignKey('chapters.id', ondelete='CASCADE'), nullable=False, comment="章节ID")
|
||||
user_id = Column(String(50), nullable=False, comment="用户ID")
|
||||
project_id = Column(String(36), nullable=False, comment="项目ID")
|
||||
|
||||
# 任务状态
|
||||
status = Column(String(20), nullable=False, default='pending', comment="任务状态: pending/running/completed/failed")
|
||||
progress = Column(Integer, default=0, comment="进度 0-100")
|
||||
error_message = Column(Text, nullable=True, comment="错误信息")
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
||||
started_at = Column(DateTime, nullable=True, comment="开始执行时间")
|
||||
completed_at = Column(DateTime, nullable=True, comment="完成时间")
|
||||
|
||||
# 索引优化查询
|
||||
__table_args__ = (
|
||||
Index('idx_chapter_id_created', 'chapter_id', 'created_at'),
|
||||
Index('idx_status', 'status'),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<AnalysisTask(id={self.id[:8]}..., chapter_id={self.chapter_id[:8]}..., status={self.status})>"
|
||||
@@ -0,0 +1,200 @@
|
||||
"""长期记忆数据模型 - 支持向量检索和剧情分析"""
|
||||
from sqlalchemy import Column, String, Text, Integer, DateTime, ForeignKey, Float, JSON, Boolean
|
||||
from sqlalchemy.sql import func
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class StoryMemory(Base):
|
||||
"""故事记忆表 - 存储结构化的故事片段和元数据"""
|
||||
__tablename__ = "story_memories"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
chapter_id = Column(String(36), ForeignKey("chapters.id", ondelete="CASCADE"), nullable=True, index=True)
|
||||
|
||||
# 记忆类型
|
||||
memory_type = Column(String(50), nullable=False, index=True, comment="""
|
||||
记忆类型:
|
||||
- plot_point: 情节点
|
||||
- character_event: 角色事件
|
||||
- world_detail: 世界观细节
|
||||
- hook: 钩子(悬念/冲突)
|
||||
- foreshadow: 伏笔
|
||||
- dialogue: 重要对话
|
||||
- scene: 场景描写
|
||||
""")
|
||||
|
||||
# 记忆内容
|
||||
title = Column(String(200), comment="记忆标题/简述")
|
||||
content = Column(Text, nullable=False, comment="记忆内容摘要(100-500字)")
|
||||
full_context = Column(Text, comment="完整上下文(可选,用于详细记录)")
|
||||
|
||||
# 关联信息
|
||||
related_characters = Column(JSON, comment="涉及角色ID列表: ['char_id_1', 'char_id_2']")
|
||||
related_locations = Column(JSON, comment="涉及地点列表: ['地点1', '地点2']")
|
||||
tags = Column(JSON, comment="标签列表: ['悬念', '转折', '伏笔', '高潮']")
|
||||
|
||||
# 重要性评分 (用于过滤和排序)
|
||||
importance_score = Column(Float, default=0.5, comment="重要性评分 0.0-1.0")
|
||||
|
||||
# 时间线定位
|
||||
story_timeline = Column(Integer, nullable=False, index=True, comment="故事时间线位置(章节序号)")
|
||||
chapter_position = Column(Integer, default=0, comment="章节内位置(字符位置)")
|
||||
text_length = Column(Integer, default=0, comment="文本长度(字符数)")
|
||||
|
||||
# 伏笔相关字段
|
||||
is_foreshadow = Column(Integer, default=0, comment="伏笔状态: 0=普通记忆, 1=已埋下伏笔, 2=伏笔已回收")
|
||||
foreshadow_resolved_at = Column(String(36), ForeignKey("chapters.id", ondelete="SET NULL"), comment="伏笔回收的章节ID")
|
||||
foreshadow_strength = Column(Float, comment="伏笔强度 0.0-1.0")
|
||||
|
||||
# 向量数据库关联
|
||||
vector_id = Column(String(100), unique=True, comment="向量数据库中的唯一ID")
|
||||
embedding_model = Column(String(100), default="paraphrase-multilingual-MiniLM-L12-v2", comment="使用的embedding模型")
|
||||
|
||||
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<StoryMemory(id={self.id[:8]}, type={self.memory_type}, title={self.title})>"
|
||||
|
||||
def to_dict(self):
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"project_id": self.project_id,
|
||||
"chapter_id": self.chapter_id,
|
||||
"memory_type": self.memory_type,
|
||||
"title": self.title,
|
||||
"content": self.content,
|
||||
"related_characters": self.related_characters,
|
||||
"related_locations": self.related_locations,
|
||||
"tags": self.tags,
|
||||
"importance_score": self.importance_score,
|
||||
"story_timeline": self.story_timeline,
|
||||
"is_foreshadow": self.is_foreshadow,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None
|
||||
}
|
||||
|
||||
|
||||
class PlotAnalysis(Base):
|
||||
"""剧情分析表 - 存储AI分析的章节结构和剧情元素"""
|
||||
__tablename__ = "plot_analysis"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
chapter_id = Column(String(36), ForeignKey("chapters.id", ondelete="CASCADE"), nullable=False, unique=True, index=True)
|
||||
|
||||
# 剧情结构分析
|
||||
plot_stage = Column(String(50), comment="剧情阶段: 开端/发展/高潮/结局/过渡")
|
||||
conflict_level = Column(Integer, comment="冲突强度 1-10")
|
||||
conflict_types = Column(JSON, comment="冲突类型列表: ['人与人', '人与己', '人与环境']")
|
||||
|
||||
# 情感分析
|
||||
emotional_tone = Column(String(100), comment="主导情感: 紧张/温馨/悲伤/激昂/平静")
|
||||
emotional_intensity = Column(Float, comment="情感强度 0.0-1.0")
|
||||
emotional_curve = Column(JSON, comment="情感曲线: {start: 0.3, middle: 0.7, end: 0.5}")
|
||||
|
||||
# 钩子分析 (Hook Analysis)
|
||||
hooks = Column(JSON, comment="""钩子列表 - 吸引读者的元素: [
|
||||
{
|
||||
"type": "悬念|情感|冲突|认知",
|
||||
"content": "具体内容",
|
||||
"strength": 8,
|
||||
"position": "开头|中段|结尾"
|
||||
}
|
||||
]""")
|
||||
hooks_count = Column(Integer, default=0, comment="钩子数量")
|
||||
hooks_avg_strength = Column(Float, comment="钩子平均强度")
|
||||
|
||||
# 伏笔分析 (Foreshadowing Analysis)
|
||||
foreshadows = Column(JSON, comment="""伏笔列表: [
|
||||
{
|
||||
"content": "伏笔内容",
|
||||
"type": "planted|resolved",
|
||||
"strength": 7,
|
||||
"subtlety": 8,
|
||||
"reference_chapter": 3
|
||||
}
|
||||
]""")
|
||||
foreshadows_planted = Column(Integer, default=0, comment="本章埋下的伏笔数量")
|
||||
foreshadows_resolved = Column(Integer, default=0, comment="本章回收的伏笔数量")
|
||||
|
||||
# 关键情节点 (Plot Points)
|
||||
plot_points = Column(JSON, comment="""情节点列表: [
|
||||
{
|
||||
"content": "情节点描述",
|
||||
"importance": 0.9,
|
||||
"type": "revelation|conflict|resolution|transition",
|
||||
"impact": "对故事的影响描述"
|
||||
}
|
||||
]""")
|
||||
plot_points_count = Column(Integer, default=0, comment="情节点数量")
|
||||
|
||||
# 角色状态追踪 (Character State Tracking)
|
||||
character_states = Column(JSON, comment="""角色状态变化: [
|
||||
{
|
||||
"character_id": "xxx",
|
||||
"character_name": "张三",
|
||||
"state_before": "犹豫不决",
|
||||
"state_after": "坚定信念",
|
||||
"psychological_change": "内心描述",
|
||||
"key_event": "触发事件",
|
||||
"relationship_changes": {"李四": "关系变化"}
|
||||
}
|
||||
]""")
|
||||
|
||||
# 场景和氛围
|
||||
scenes = Column(JSON, comment="场景列表: [{location: '地点', atmosphere: '氛围', duration: '时长'}]")
|
||||
pacing = Column(String(50), comment="节奏: slow|moderate|fast|varied")
|
||||
|
||||
# 质量评分
|
||||
overall_quality_score = Column(Float, comment="整体质量评分 0.0-10.0")
|
||||
pacing_score = Column(Float, comment="节奏评分 0.0-10.0")
|
||||
engagement_score = Column(Float, comment="吸引力评分 0.0-10.0")
|
||||
coherence_score = Column(Float, comment="连贯性评分 0.0-10.0")
|
||||
|
||||
# 文本分析报告
|
||||
analysis_report = Column(Text, comment="完整的文字分析报告")
|
||||
suggestions = Column(JSON, comment="改进建议列表: ['建议1', '建议2']")
|
||||
|
||||
# 统计信息
|
||||
word_count = Column(Integer, comment="章节字数")
|
||||
dialogue_ratio = Column(Float, comment="对话占比 0.0-1.0")
|
||||
description_ratio = Column(Float, comment="描写占比 0.0-1.0")
|
||||
|
||||
created_at = Column(DateTime, server_default=func.now(), comment="分析时间")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<PlotAnalysis(chapter_id={self.chapter_id[:8]}, stage={self.plot_stage}, quality={self.overall_quality_score})>"
|
||||
|
||||
def to_dict(self):
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"chapter_id": self.chapter_id,
|
||||
"plot_stage": self.plot_stage,
|
||||
"conflict_level": self.conflict_level,
|
||||
"conflict_types": self.conflict_types or [],
|
||||
"emotional_tone": self.emotional_tone,
|
||||
"emotional_intensity": self.emotional_intensity or 0.0,
|
||||
"hooks": self.hooks or [],
|
||||
"hooks_count": self.hooks_count or 0,
|
||||
"foreshadows": self.foreshadows or [],
|
||||
"foreshadows_planted": self.foreshadows_planted or 0,
|
||||
"foreshadows_resolved": self.foreshadows_resolved or 0,
|
||||
"plot_points": self.plot_points or [],
|
||||
"plot_points_count": self.plot_points_count or 0,
|
||||
"character_states": self.character_states or [],
|
||||
"scenes": self.scenes or [],
|
||||
"pacing": self.pacing,
|
||||
"overall_quality_score": self.overall_quality_score or 0.0,
|
||||
"pacing_score": self.pacing_score or 0.0,
|
||||
"engagement_score": self.engagement_score or 0.0,
|
||||
"coherence_score": self.coherence_score or 0.0,
|
||||
"analysis_report": self.analysis_report,
|
||||
"suggestions": self.suggestions or [],
|
||||
"dialogue_ratio": self.dialogue_ratio or 0.0,
|
||||
"description_ratio": self.description_ratio or 0.0,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None
|
||||
}
|
||||
@@ -0,0 +1,136 @@
|
||||
"""导入导出相关的Pydantic模型"""
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class ExportOptions(BaseModel):
|
||||
"""导出选项"""
|
||||
include_generation_history: bool = Field(False, description="是否包含生成历史")
|
||||
include_writing_styles: bool = Field(True, description="是否包含写作风格")
|
||||
|
||||
|
||||
class ChapterExportData(BaseModel):
|
||||
"""章节导出数据"""
|
||||
title: str
|
||||
content: Optional[str] = None
|
||||
summary: Optional[str] = None
|
||||
chapter_number: int
|
||||
word_count: int = 0
|
||||
status: str = "draft"
|
||||
created_at: Optional[str] = None
|
||||
|
||||
|
||||
class CharacterExportData(BaseModel):
|
||||
"""角色导出数据"""
|
||||
name: str
|
||||
age: Optional[str] = None
|
||||
gender: Optional[str] = None
|
||||
is_organization: bool = False
|
||||
role_type: Optional[str] = None
|
||||
personality: Optional[str] = None
|
||||
background: Optional[str] = None
|
||||
appearance: Optional[str] = None
|
||||
traits: Optional[List[str]] = None
|
||||
organization_type: Optional[str] = None
|
||||
organization_purpose: Optional[str] = None
|
||||
created_at: Optional[str] = None
|
||||
|
||||
|
||||
class OutlineExportData(BaseModel):
|
||||
"""大纲导出数据"""
|
||||
title: str
|
||||
content: Optional[str] = None
|
||||
structure: Optional[str] = None
|
||||
order_index: Optional[int] = None
|
||||
created_at: Optional[str] = None
|
||||
|
||||
|
||||
class RelationshipExportData(BaseModel):
|
||||
"""关系导出数据"""
|
||||
source_name: str
|
||||
target_name: str
|
||||
relationship_name: Optional[str] = None
|
||||
intimacy_level: int = 50
|
||||
status: str = "active"
|
||||
description: Optional[str] = None
|
||||
started_at: Optional[str] = None
|
||||
|
||||
|
||||
class OrganizationExportData(BaseModel):
|
||||
"""组织详情导出数据"""
|
||||
character_name: str
|
||||
parent_org_name: Optional[str] = None
|
||||
power_level: int = 50
|
||||
member_count: int = 0
|
||||
location: Optional[str] = None
|
||||
motto: Optional[str] = None
|
||||
color: Optional[str] = None
|
||||
|
||||
|
||||
class OrganizationMemberExportData(BaseModel):
|
||||
"""组织成员导出数据"""
|
||||
organization_name: str
|
||||
character_name: str
|
||||
position: str
|
||||
rank: int = 0
|
||||
status: str = "active"
|
||||
joined_at: Optional[str] = None
|
||||
loyalty: int = 50
|
||||
contribution: int = 0
|
||||
notes: Optional[str] = None
|
||||
|
||||
|
||||
class WritingStyleExportData(BaseModel):
|
||||
"""写作风格导出数据"""
|
||||
name: str
|
||||
style_type: str
|
||||
preset_id: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
prompt_content: str
|
||||
order_index: int = 0
|
||||
|
||||
|
||||
class GenerationHistoryExportData(BaseModel):
|
||||
"""生成历史导出数据"""
|
||||
chapter_title: Optional[str] = None
|
||||
prompt: Optional[str] = None
|
||||
generated_content: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
tokens_used: Optional[int] = None
|
||||
generation_time: Optional[float] = None
|
||||
created_at: Optional[str] = None
|
||||
|
||||
|
||||
class ProjectExportData(BaseModel):
|
||||
"""项目完整导出数据"""
|
||||
version: str = "1.0.0"
|
||||
export_time: str
|
||||
project: Dict[str, Any]
|
||||
chapters: List[ChapterExportData] = []
|
||||
characters: List[CharacterExportData] = []
|
||||
outlines: List[OutlineExportData] = []
|
||||
relationships: List[RelationshipExportData] = []
|
||||
organizations: List[OrganizationExportData] = []
|
||||
organization_members: List[OrganizationMemberExportData] = []
|
||||
writing_styles: List[WritingStyleExportData] = []
|
||||
generation_history: List[GenerationHistoryExportData] = []
|
||||
|
||||
|
||||
class ImportValidationResult(BaseModel):
|
||||
"""导入验证结果"""
|
||||
valid: bool
|
||||
version: str
|
||||
project_name: Optional[str] = None
|
||||
statistics: Dict[str, int] = {}
|
||||
errors: List[str] = []
|
||||
warnings: List[str] = []
|
||||
|
||||
|
||||
class ImportResult(BaseModel):
|
||||
"""导入结果"""
|
||||
success: bool
|
||||
project_id: Optional[str] = None
|
||||
message: str
|
||||
statistics: Dict[str, int] = {}
|
||||
warnings: List[str] = []
|
||||
@@ -0,0 +1,769 @@
|
||||
"""导入导出服务"""
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.models.project import Project
|
||||
from app.models.chapter import Chapter
|
||||
from app.models.character import Character
|
||||
from app.models.outline import Outline
|
||||
from app.models.relationship import CharacterRelationship, Organization, OrganizationMember
|
||||
from app.models.writing_style import WritingStyle
|
||||
from app.models.generation_history import GenerationHistory
|
||||
from app.schemas.import_export import (
|
||||
ProjectExportData,
|
||||
ChapterExportData,
|
||||
CharacterExportData,
|
||||
OutlineExportData,
|
||||
RelationshipExportData,
|
||||
OrganizationExportData,
|
||||
OrganizationMemberExportData,
|
||||
WritingStyleExportData,
|
||||
GenerationHistoryExportData,
|
||||
ImportValidationResult,
|
||||
ImportResult
|
||||
)
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ImportExportService:
|
||||
"""导入导出服务类"""
|
||||
|
||||
SUPPORTED_VERSION = "1.0.0"
|
||||
|
||||
@staticmethod
|
||||
async def export_project(
|
||||
project_id: str,
|
||||
db: AsyncSession,
|
||||
include_generation_history: bool = False,
|
||||
include_writing_styles: bool = True
|
||||
) -> ProjectExportData:
|
||||
"""
|
||||
导出项目完整数据
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
db: 数据库会话
|
||||
include_generation_history: 是否包含生成历史
|
||||
include_writing_styles: 是否包含写作风格
|
||||
|
||||
Returns:
|
||||
ProjectExportData: 导出的项目数据
|
||||
"""
|
||||
logger.info(f"开始导出项目: {project_id}")
|
||||
|
||||
# 获取项目基本信息
|
||||
result = await db.execute(select(Project).where(Project.id == project_id))
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise ValueError(f"项目不存在: {project_id}")
|
||||
|
||||
# 项目基本信息
|
||||
project_data = {
|
||||
"title": project.title,
|
||||
"description": project.description,
|
||||
"theme": project.theme,
|
||||
"genre": project.genre,
|
||||
"target_words": project.target_words,
|
||||
"current_words": project.current_words,
|
||||
"status": project.status,
|
||||
"world_time_period": project.world_time_period,
|
||||
"world_location": project.world_location,
|
||||
"world_atmosphere": project.world_atmosphere,
|
||||
"world_rules": project.world_rules,
|
||||
"chapter_count": project.chapter_count,
|
||||
"narrative_perspective": project.narrative_perspective,
|
||||
"character_count": project.character_count,
|
||||
"created_at": project.created_at.isoformat() if project.created_at else None,
|
||||
}
|
||||
|
||||
# 导出章节
|
||||
chapters = await ImportExportService._export_chapters(project_id, db)
|
||||
logger.info(f"导出章节数: {len(chapters)}")
|
||||
|
||||
# 导出角色
|
||||
characters = await ImportExportService._export_characters(project_id, db)
|
||||
logger.info(f"导出角色数: {len(characters)}")
|
||||
|
||||
# 导出大纲
|
||||
outlines = await ImportExportService._export_outlines(project_id, db)
|
||||
logger.info(f"导出大纲数: {len(outlines)}")
|
||||
|
||||
# 导出关系
|
||||
relationships = await ImportExportService._export_relationships(project_id, db)
|
||||
logger.info(f"导出关系数: {len(relationships)}")
|
||||
|
||||
# 导出组织详情
|
||||
organizations = await ImportExportService._export_organizations(project_id, db)
|
||||
logger.info(f"导出组织数: {len(organizations)}")
|
||||
|
||||
# 导出组织成员
|
||||
org_members = await ImportExportService._export_organization_members(project_id, db)
|
||||
logger.info(f"导出组织成员数: {len(org_members)}")
|
||||
|
||||
# 导出写作风格(可选)
|
||||
writing_styles = []
|
||||
if include_writing_styles:
|
||||
writing_styles = await ImportExportService._export_writing_styles(project_id, db)
|
||||
logger.info(f"导出写作风格数: {len(writing_styles)}")
|
||||
|
||||
# 导出生成历史(可选)
|
||||
generation_history = []
|
||||
if include_generation_history:
|
||||
generation_history = await ImportExportService._export_generation_history(project_id, db)
|
||||
logger.info(f"导出生成历史数: {len(generation_history)}")
|
||||
|
||||
export_data = ProjectExportData(
|
||||
version=ImportExportService.SUPPORTED_VERSION,
|
||||
export_time=datetime.utcnow().isoformat(),
|
||||
project=project_data,
|
||||
chapters=chapters,
|
||||
characters=characters,
|
||||
outlines=outlines,
|
||||
relationships=relationships,
|
||||
organizations=organizations,
|
||||
organization_members=org_members,
|
||||
writing_styles=writing_styles,
|
||||
generation_history=generation_history
|
||||
)
|
||||
|
||||
logger.info(f"项目导出完成: {project_id}")
|
||||
return export_data
|
||||
|
||||
@staticmethod
|
||||
async def _export_chapters(project_id: str, db: AsyncSession) -> List[ChapterExportData]:
|
||||
"""导出章节"""
|
||||
result = await db.execute(
|
||||
select(Chapter)
|
||||
.where(Chapter.project_id == project_id)
|
||||
.order_by(Chapter.chapter_number)
|
||||
)
|
||||
chapters = result.scalars().all()
|
||||
|
||||
return [
|
||||
ChapterExportData(
|
||||
title=ch.title,
|
||||
content=ch.content,
|
||||
summary=ch.summary,
|
||||
chapter_number=ch.chapter_number,
|
||||
word_count=ch.word_count or 0,
|
||||
status=ch.status,
|
||||
created_at=ch.created_at.isoformat() if ch.created_at else None
|
||||
)
|
||||
for ch in chapters
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
async def _export_characters(project_id: str, db: AsyncSession) -> List[CharacterExportData]:
|
||||
"""导出角色"""
|
||||
result = await db.execute(
|
||||
select(Character).where(Character.project_id == project_id)
|
||||
)
|
||||
characters = result.scalars().all()
|
||||
|
||||
exported = []
|
||||
for char in characters:
|
||||
# 解析traits JSON
|
||||
traits = None
|
||||
if char.traits:
|
||||
try:
|
||||
traits = json.loads(char.traits) if isinstance(char.traits, str) else char.traits
|
||||
except:
|
||||
traits = None
|
||||
|
||||
exported.append(CharacterExportData(
|
||||
name=char.name,
|
||||
age=char.age,
|
||||
gender=char.gender,
|
||||
is_organization=char.is_organization or False,
|
||||
role_type=char.role_type,
|
||||
personality=char.personality,
|
||||
background=char.background,
|
||||
appearance=char.appearance,
|
||||
traits=traits,
|
||||
organization_type=char.organization_type,
|
||||
organization_purpose=char.organization_purpose,
|
||||
created_at=char.created_at.isoformat() if char.created_at else None
|
||||
))
|
||||
|
||||
return exported
|
||||
|
||||
@staticmethod
|
||||
async def _export_outlines(project_id: str, db: AsyncSession) -> List[OutlineExportData]:
|
||||
"""导出大纲"""
|
||||
result = await db.execute(
|
||||
select(Outline)
|
||||
.where(Outline.project_id == project_id)
|
||||
.order_by(Outline.order_index)
|
||||
)
|
||||
outlines = result.scalars().all()
|
||||
|
||||
return [
|
||||
OutlineExportData(
|
||||
title=ol.title,
|
||||
content=ol.content,
|
||||
structure=ol.structure,
|
||||
order_index=ol.order_index,
|
||||
created_at=ol.created_at.isoformat() if ol.created_at else None
|
||||
)
|
||||
for ol in outlines
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
async def _export_relationships(project_id: str, db: AsyncSession) -> List[RelationshipExportData]:
|
||||
"""导出关系"""
|
||||
result = await db.execute(
|
||||
select(CharacterRelationship, Character)
|
||||
.join(Character, CharacterRelationship.character_from_id == Character.id)
|
||||
.where(CharacterRelationship.project_id == project_id)
|
||||
)
|
||||
relationships = result.all()
|
||||
|
||||
exported = []
|
||||
for rel, char_from in relationships:
|
||||
# 获取目标角色名称
|
||||
target_result = await db.execute(
|
||||
select(Character).where(Character.id == rel.character_to_id)
|
||||
)
|
||||
char_to = target_result.scalar_one_or_none()
|
||||
|
||||
if char_to:
|
||||
exported.append(RelationshipExportData(
|
||||
source_name=char_from.name,
|
||||
target_name=char_to.name,
|
||||
relationship_name=rel.relationship_name,
|
||||
intimacy_level=rel.intimacy_level or 50,
|
||||
status=rel.status or "active",
|
||||
description=rel.description,
|
||||
started_at=rel.started_at
|
||||
))
|
||||
|
||||
return exported
|
||||
|
||||
@staticmethod
|
||||
async def _export_organizations(project_id: str, db: AsyncSession) -> List[OrganizationExportData]:
|
||||
"""导出组织详情"""
|
||||
result = await db.execute(
|
||||
select(Organization, Character)
|
||||
.join(Character, Organization.character_id == Character.id)
|
||||
.where(Organization.project_id == project_id)
|
||||
)
|
||||
organizations = result.all()
|
||||
|
||||
exported = []
|
||||
for org, char in organizations:
|
||||
# 获取父组织名称
|
||||
parent_name = None
|
||||
if org.parent_org_id:
|
||||
parent_result = await db.execute(
|
||||
select(Organization, Character)
|
||||
.join(Character, Organization.character_id == Character.id)
|
||||
.where(Organization.id == org.parent_org_id)
|
||||
)
|
||||
parent_data = parent_result.first()
|
||||
if parent_data:
|
||||
parent_name = parent_data[1].name
|
||||
|
||||
exported.append(OrganizationExportData(
|
||||
character_name=char.name,
|
||||
parent_org_name=parent_name,
|
||||
power_level=org.power_level or 50,
|
||||
member_count=org.member_count or 0,
|
||||
location=org.location,
|
||||
motto=org.motto,
|
||||
color=org.color
|
||||
))
|
||||
|
||||
return exported
|
||||
|
||||
@staticmethod
|
||||
async def _export_organization_members(project_id: str, db: AsyncSession) -> List[OrganizationMemberExportData]:
|
||||
"""导出组织成员"""
|
||||
result = await db.execute(
|
||||
select(OrganizationMember, Organization, Character)
|
||||
.join(Organization, OrganizationMember.organization_id == Organization.id)
|
||||
.join(Character, Organization.character_id == Character.id)
|
||||
.where(Organization.project_id == project_id)
|
||||
)
|
||||
members = result.all()
|
||||
|
||||
exported = []
|
||||
for member, org, org_char in members:
|
||||
# 获取成员角色名称
|
||||
char_result = await db.execute(
|
||||
select(Character).where(Character.id == member.character_id)
|
||||
)
|
||||
member_char = char_result.scalar_one_or_none()
|
||||
|
||||
if member_char:
|
||||
exported.append(OrganizationMemberExportData(
|
||||
organization_name=org_char.name,
|
||||
character_name=member_char.name,
|
||||
position=member.position,
|
||||
rank=member.rank or 0,
|
||||
status=member.status or "active",
|
||||
joined_at=member.joined_at,
|
||||
loyalty=member.loyalty or 50,
|
||||
contribution=member.contribution or 0,
|
||||
notes=member.notes
|
||||
))
|
||||
|
||||
return exported
|
||||
|
||||
@staticmethod
|
||||
async def _export_writing_styles(project_id: str, db: AsyncSession) -> List[WritingStyleExportData]:
|
||||
"""导出写作风格"""
|
||||
result = await db.execute(
|
||||
select(WritingStyle)
|
||||
.where(WritingStyle.project_id == project_id)
|
||||
.order_by(WritingStyle.order_index)
|
||||
)
|
||||
styles = result.scalars().all()
|
||||
|
||||
return [
|
||||
WritingStyleExportData(
|
||||
name=style.name,
|
||||
style_type=style.style_type,
|
||||
preset_id=style.preset_id,
|
||||
description=style.description,
|
||||
prompt_content=style.prompt_content,
|
||||
order_index=style.order_index or 0
|
||||
)
|
||||
for style in styles
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
async def _export_generation_history(project_id: str, db: AsyncSession) -> List[GenerationHistoryExportData]:
|
||||
"""导出生成历史"""
|
||||
result = await db.execute(
|
||||
select(GenerationHistory, Chapter)
|
||||
.outerjoin(Chapter, GenerationHistory.chapter_id == Chapter.id)
|
||||
.where(GenerationHistory.project_id == project_id)
|
||||
.order_by(GenerationHistory.created_at.desc())
|
||||
.limit(100) # 限制最多导出100条历史记录
|
||||
)
|
||||
histories = result.all()
|
||||
|
||||
return [
|
||||
GenerationHistoryExportData(
|
||||
chapter_title=chapter.title if chapter else None,
|
||||
prompt=history.prompt,
|
||||
generated_content=history.generated_content,
|
||||
model=history.model,
|
||||
tokens_used=history.tokens_used,
|
||||
generation_time=history.generation_time,
|
||||
created_at=history.created_at.isoformat() if history.created_at else None
|
||||
)
|
||||
for history, chapter in histories
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def validate_import_data(data: Dict) -> ImportValidationResult:
|
||||
"""
|
||||
验证导入数据
|
||||
|
||||
Args:
|
||||
data: 导入的JSON数据
|
||||
|
||||
Returns:
|
||||
ImportValidationResult: 验证结果
|
||||
"""
|
||||
errors = []
|
||||
warnings = []
|
||||
statistics = {}
|
||||
|
||||
# 检查版本
|
||||
version = data.get("version", "")
|
||||
if not version:
|
||||
errors.append("缺少版本信息")
|
||||
elif version != ImportExportService.SUPPORTED_VERSION:
|
||||
warnings.append(f"版本不匹配: 导入文件版本为 {version}, 当前支持版本为 {ImportExportService.SUPPORTED_VERSION}")
|
||||
|
||||
# 检查必需字段
|
||||
if "project" not in data:
|
||||
errors.append("缺少项目信息")
|
||||
else:
|
||||
project = data["project"]
|
||||
if not project.get("title"):
|
||||
errors.append("项目标题不能为空")
|
||||
|
||||
# 统计数据
|
||||
statistics = {
|
||||
"chapters": len(data.get("chapters", [])),
|
||||
"characters": len(data.get("characters", [])),
|
||||
"outlines": len(data.get("outlines", [])),
|
||||
"relationships": len(data.get("relationships", [])),
|
||||
"organizations": len(data.get("organizations", [])),
|
||||
"organization_members": len(data.get("organization_members", [])),
|
||||
"writing_styles": len(data.get("writing_styles", [])),
|
||||
"generation_history": len(data.get("generation_history", []))
|
||||
}
|
||||
|
||||
# 检查数据完整性
|
||||
if statistics["chapters"] == 0:
|
||||
warnings.append("项目没有章节数据")
|
||||
|
||||
if statistics["characters"] == 0:
|
||||
warnings.append("项目没有角色数据")
|
||||
|
||||
project_name = data.get("project", {}).get("title", "未知项目")
|
||||
|
||||
return ImportValidationResult(
|
||||
valid=len(errors) == 0,
|
||||
version=version,
|
||||
project_name=project_name,
|
||||
statistics=statistics,
|
||||
errors=errors,
|
||||
warnings=warnings
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def import_project(
|
||||
data: Dict,
|
||||
db: AsyncSession
|
||||
) -> ImportResult:
|
||||
"""
|
||||
导入项目数据(创建新项目)
|
||||
|
||||
Args:
|
||||
data: 导入的JSON数据
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
ImportResult: 导入结果
|
||||
"""
|
||||
warnings = []
|
||||
statistics = {}
|
||||
|
||||
try:
|
||||
# 验证数据
|
||||
validation = ImportExportService.validate_import_data(data)
|
||||
if not validation.valid:
|
||||
return ImportResult(
|
||||
success=False,
|
||||
message=f"数据验证失败: {', '.join(validation.errors)}",
|
||||
statistics={},
|
||||
warnings=validation.warnings
|
||||
)
|
||||
|
||||
warnings.extend(validation.warnings)
|
||||
|
||||
logger.info(f"开始导入项目: {validation.project_name}")
|
||||
|
||||
# 创建项目
|
||||
project_data = data["project"]
|
||||
new_project = Project(
|
||||
title=project_data.get("title"),
|
||||
description=project_data.get("description"),
|
||||
theme=project_data.get("theme"),
|
||||
genre=project_data.get("genre"),
|
||||
target_words=project_data.get("target_words"),
|
||||
status=project_data.get("status", "planning"),
|
||||
world_time_period=project_data.get("world_time_period"),
|
||||
world_location=project_data.get("world_location"),
|
||||
world_atmosphere=project_data.get("world_atmosphere"),
|
||||
world_rules=project_data.get("world_rules"),
|
||||
chapter_count=project_data.get("chapter_count"),
|
||||
narrative_perspective=project_data.get("narrative_perspective"),
|
||||
character_count=project_data.get("character_count"),
|
||||
current_words=project_data.get("current_words", 0), # 保留原项目的字数
|
||||
wizard_step=4, # 导入的项目设置为向导完成状态
|
||||
wizard_status="completed" # 标记向导已完成
|
||||
)
|
||||
db.add(new_project)
|
||||
await db.flush() # 获取project_id
|
||||
|
||||
logger.info(f"创建项目成功: {new_project.id}")
|
||||
|
||||
# 导入章节
|
||||
chapters_count = await ImportExportService._import_chapters(
|
||||
new_project.id, data.get("chapters", []), db
|
||||
)
|
||||
statistics["chapters"] = chapters_count
|
||||
logger.info(f"导入章节数: {chapters_count}")
|
||||
|
||||
# 导入角色(包括组织)
|
||||
char_mapping = await ImportExportService._import_characters(
|
||||
new_project.id, data.get("characters", []), db
|
||||
)
|
||||
statistics["characters"] = len(char_mapping)
|
||||
logger.info(f"导入角色数: {len(char_mapping)}")
|
||||
|
||||
# 导入大纲
|
||||
outlines_count = await ImportExportService._import_outlines(
|
||||
new_project.id, data.get("outlines", []), db
|
||||
)
|
||||
statistics["outlines"] = outlines_count
|
||||
logger.info(f"导入大纲数: {outlines_count}")
|
||||
|
||||
# 导入关系
|
||||
relationships_count = await ImportExportService._import_relationships(
|
||||
new_project.id, data.get("relationships", []), char_mapping, db
|
||||
)
|
||||
statistics["relationships"] = relationships_count
|
||||
logger.info(f"导入关系数: {relationships_count}")
|
||||
|
||||
# 导入组织详情
|
||||
org_mapping = await ImportExportService._import_organizations(
|
||||
new_project.id, data.get("organizations", []), char_mapping, db
|
||||
)
|
||||
statistics["organizations"] = len(org_mapping)
|
||||
logger.info(f"导入组织数: {len(org_mapping)}")
|
||||
|
||||
# 导入组织成员
|
||||
org_members_count = await ImportExportService._import_organization_members(
|
||||
data.get("organization_members", []), char_mapping, org_mapping, db
|
||||
)
|
||||
statistics["organization_members"] = org_members_count
|
||||
logger.info(f"导入组织成员数: {org_members_count}")
|
||||
|
||||
# 导入写作风格
|
||||
styles_count = await ImportExportService._import_writing_styles(
|
||||
new_project.id, data.get("writing_styles", []), db
|
||||
)
|
||||
statistics["writing_styles"] = styles_count
|
||||
logger.info(f"导入写作风格数: {styles_count}")
|
||||
|
||||
# 提交事务
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"项目导入完成: {new_project.id}")
|
||||
|
||||
return ImportResult(
|
||||
success=True,
|
||||
project_id=new_project.id,
|
||||
message="项目导入成功",
|
||||
statistics=statistics,
|
||||
warnings=warnings
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"导入项目失败: {str(e)}", exc_info=True)
|
||||
return ImportResult(
|
||||
success=False,
|
||||
message=f"导入失败: {str(e)}",
|
||||
statistics=statistics,
|
||||
warnings=warnings
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _import_chapters(
|
||||
project_id: str,
|
||||
chapters_data: List[Dict],
|
||||
db: AsyncSession
|
||||
) -> int:
|
||||
"""导入章节"""
|
||||
count = 0
|
||||
for ch_data in chapters_data:
|
||||
chapter = Chapter(
|
||||
project_id=project_id,
|
||||
title=ch_data.get("title"),
|
||||
content=ch_data.get("content"),
|
||||
summary=ch_data.get("summary"),
|
||||
chapter_number=ch_data.get("chapter_number"),
|
||||
word_count=ch_data.get("word_count", 0),
|
||||
status=ch_data.get("status", "draft")
|
||||
)
|
||||
db.add(chapter)
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
@staticmethod
|
||||
async def _import_characters(
|
||||
project_id: str,
|
||||
characters_data: List[Dict],
|
||||
db: AsyncSession
|
||||
) -> Dict[str, str]:
|
||||
"""导入角色,返回名称到ID的映射"""
|
||||
char_mapping = {}
|
||||
|
||||
for char_data in characters_data:
|
||||
# 处理traits
|
||||
traits = char_data.get("traits")
|
||||
if traits and isinstance(traits, list):
|
||||
traits = json.dumps(traits, ensure_ascii=False)
|
||||
|
||||
character = Character(
|
||||
project_id=project_id,
|
||||
name=char_data.get("name"),
|
||||
age=char_data.get("age"),
|
||||
gender=char_data.get("gender"),
|
||||
is_organization=char_data.get("is_organization", False),
|
||||
role_type=char_data.get("role_type"),
|
||||
personality=char_data.get("personality"),
|
||||
background=char_data.get("background"),
|
||||
appearance=char_data.get("appearance"),
|
||||
traits=traits,
|
||||
organization_type=char_data.get("organization_type"),
|
||||
organization_purpose=char_data.get("organization_purpose")
|
||||
)
|
||||
db.add(character)
|
||||
await db.flush() # 获取ID
|
||||
char_mapping[char_data.get("name")] = character.id
|
||||
|
||||
return char_mapping
|
||||
|
||||
@staticmethod
|
||||
async def _import_outlines(
|
||||
project_id: str,
|
||||
outlines_data: List[Dict],
|
||||
db: AsyncSession
|
||||
) -> int:
|
||||
"""导入大纲"""
|
||||
count = 0
|
||||
for ol_data in outlines_data:
|
||||
outline = Outline(
|
||||
project_id=project_id,
|
||||
title=ol_data.get("title"),
|
||||
content=ol_data.get("content"),
|
||||
structure=ol_data.get("structure"),
|
||||
order_index=ol_data.get("order_index")
|
||||
)
|
||||
db.add(outline)
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
@staticmethod
|
||||
async def _import_relationships(
|
||||
project_id: str,
|
||||
relationships_data: List[Dict],
|
||||
char_mapping: Dict[str, str],
|
||||
db: AsyncSession
|
||||
) -> int:
|
||||
"""导入关系"""
|
||||
count = 0
|
||||
for rel_data in relationships_data:
|
||||
source_name = rel_data.get("source_name")
|
||||
target_name = rel_data.get("target_name")
|
||||
|
||||
# 查找角色ID
|
||||
source_id = char_mapping.get(source_name)
|
||||
target_id = char_mapping.get(target_name)
|
||||
|
||||
if source_id and target_id:
|
||||
relationship = CharacterRelationship(
|
||||
project_id=project_id,
|
||||
character_from_id=source_id,
|
||||
character_to_id=target_id,
|
||||
relationship_name=rel_data.get("relationship_name"),
|
||||
intimacy_level=rel_data.get("intimacy_level", 50),
|
||||
status=rel_data.get("status", "active"),
|
||||
description=rel_data.get("description"),
|
||||
started_at=rel_data.get("started_at")
|
||||
)
|
||||
db.add(relationship)
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
@staticmethod
|
||||
async def _import_organizations(
|
||||
project_id: str,
|
||||
organizations_data: List[Dict],
|
||||
char_mapping: Dict[str, str],
|
||||
db: AsyncSession
|
||||
) -> Dict[str, str]:
|
||||
"""导入组织详情,返回名称到ID的映射"""
|
||||
org_mapping = {}
|
||||
|
||||
# 第一遍:创建所有组织(不设置父组织)
|
||||
temp_orgs = []
|
||||
for org_data in organizations_data:
|
||||
char_name = org_data.get("character_name")
|
||||
char_id = char_mapping.get(char_name)
|
||||
|
||||
if char_id:
|
||||
organization = Organization(
|
||||
project_id=project_id,
|
||||
character_id=char_id,
|
||||
power_level=org_data.get("power_level", 50),
|
||||
member_count=org_data.get("member_count", 0),
|
||||
location=org_data.get("location"),
|
||||
motto=org_data.get("motto"),
|
||||
color=org_data.get("color")
|
||||
)
|
||||
db.add(organization)
|
||||
temp_orgs.append((organization, org_data.get("parent_org_name")))
|
||||
|
||||
await db.flush() # 获取所有组织的ID
|
||||
|
||||
# 建立名称到ID的映射
|
||||
for org, _ in temp_orgs:
|
||||
# 通过character_id查找角色名
|
||||
result = await db.execute(
|
||||
select(Character).where(Character.id == org.character_id)
|
||||
)
|
||||
char = result.scalar_one_or_none()
|
||||
if char:
|
||||
org_mapping[char.name] = org.id
|
||||
|
||||
# 第二遍:设置父组织关系
|
||||
for org, parent_name in temp_orgs:
|
||||
if parent_name:
|
||||
parent_id = org_mapping.get(parent_name)
|
||||
if parent_id:
|
||||
org.parent_org_id = parent_id
|
||||
|
||||
return org_mapping
|
||||
|
||||
@staticmethod
|
||||
async def _import_organization_members(
|
||||
org_members_data: List[Dict],
|
||||
char_mapping: Dict[str, str],
|
||||
org_mapping: Dict[str, str],
|
||||
db: AsyncSession
|
||||
) -> int:
|
||||
"""导入组织成员"""
|
||||
count = 0
|
||||
for member_data in org_members_data:
|
||||
org_name = member_data.get("organization_name")
|
||||
char_name = member_data.get("character_name")
|
||||
|
||||
org_id = org_mapping.get(org_name)
|
||||
char_id = char_mapping.get(char_name)
|
||||
|
||||
if org_id and char_id:
|
||||
member = OrganizationMember(
|
||||
organization_id=org_id,
|
||||
character_id=char_id,
|
||||
position=member_data.get("position"),
|
||||
rank=member_data.get("rank", 0),
|
||||
status=member_data.get("status", "active"),
|
||||
joined_at=member_data.get("joined_at"),
|
||||
loyalty=member_data.get("loyalty", 50),
|
||||
contribution=member_data.get("contribution", 0),
|
||||
notes=member_data.get("notes")
|
||||
)
|
||||
db.add(member)
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
@staticmethod
|
||||
async def _import_writing_styles(
|
||||
project_id: str,
|
||||
styles_data: List[Dict],
|
||||
db: AsyncSession
|
||||
) -> int:
|
||||
"""导入写作风格"""
|
||||
count = 0
|
||||
for style_data in styles_data:
|
||||
style = WritingStyle(
|
||||
project_id=project_id,
|
||||
name=style_data.get("name"),
|
||||
style_type=style_data.get("style_type"),
|
||||
preset_id=style_data.get("preset_id"),
|
||||
description=style_data.get("description"),
|
||||
prompt_content=style_data.get("prompt_content"),
|
||||
order_index=style_data.get("order_index", 0)
|
||||
)
|
||||
db.add(style)
|
||||
count += 1
|
||||
|
||||
return count
|
||||
@@ -0,0 +1,739 @@
|
||||
"""向量记忆服务 - 基于ChromaDB实现长期记忆和语义检索"""
|
||||
import chromadb
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from typing import List, Dict, Any, Optional
|
||||
import json
|
||||
from datetime import datetime
|
||||
from app.logger import get_logger
|
||||
import os
|
||||
import hashlib
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 配置离线模式,避免联网检查
|
||||
os.environ['TRANSFORMERS_OFFLINE'] = '1'
|
||||
os.environ['HF_DATASETS_OFFLINE'] = '1'
|
||||
|
||||
|
||||
class MemoryService:
|
||||
"""向量记忆管理服务 - 实现语义检索和长期记忆"""
|
||||
|
||||
_instance = None
|
||||
_initialized = False
|
||||
|
||||
def __new__(cls):
|
||||
"""单例模式"""
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
"""初始化ChromaDB和Embedding模型"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
try:
|
||||
# 确保数据目录存在
|
||||
chroma_dir = "data/chroma_db"
|
||||
os.makedirs(chroma_dir, exist_ok=True)
|
||||
|
||||
# 初始化ChromaDB客户端(使用新API - PersistentClient)
|
||||
self.client = chromadb.PersistentClient(path=chroma_dir)
|
||||
|
||||
# 初始化多语言embedding模型(支持中文)
|
||||
logger.info("🔄 正在加载Embedding模型...")
|
||||
|
||||
# 确保模型缓存目录存在
|
||||
model_cache_dir = 'data/models'
|
||||
os.makedirs(model_cache_dir, exist_ok=True)
|
||||
|
||||
try:
|
||||
# 优先使用本地缓存的模型
|
||||
# cache_folder会让模型优先从本地加载,只有不存在时才联网下载
|
||||
self.embedding_model = SentenceTransformer(
|
||||
'paraphrase-multilingual-MiniLM-L12-v2',
|
||||
cache_folder=model_cache_dir,
|
||||
device='cpu' # 明确指定使用CPU
|
||||
)
|
||||
logger.info("✅ Embedding模型加载成功")
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 无法加载多语言模型: {str(e)}")
|
||||
logger.info("🔄 尝试使用备用模型...")
|
||||
try:
|
||||
# 降级到更小的模型作为备选
|
||||
self.embedding_model = SentenceTransformer(
|
||||
'all-MiniLM-L6-v2',
|
||||
cache_folder=model_cache_dir,
|
||||
device='cpu'
|
||||
)
|
||||
logger.info("✅ 使用备用Embedding模型")
|
||||
except Exception as e2:
|
||||
logger.error(f"❌ 所有模型加载失败: {str(e2)}")
|
||||
logger.error("💡 模型首次使用需要联网下载(约420MB)")
|
||||
logger.error(" 或手动下载模型文件到 data/models 目录")
|
||||
raise RuntimeError("无法加载任何Embedding模型")
|
||||
|
||||
self._initialized = True
|
||||
logger.info("✅ MemoryService初始化成功")
|
||||
logger.info(f" - ChromaDB目录: {chroma_dir}")
|
||||
logger.info(f" - Embedding模型: paraphrase-multilingual-MiniLM-L12-v2")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ MemoryService初始化失败: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_collection(self, user_id: str, project_id: str):
|
||||
"""
|
||||
获取或创建项目的记忆集合
|
||||
|
||||
每个用户的每个项目有独立的collection,实现数据隔离
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
project_id: 项目ID
|
||||
|
||||
Returns:
|
||||
ChromaDB Collection对象
|
||||
"""
|
||||
# ChromaDB collection命名规则:
|
||||
# 1. 3-63字符(最重要!)
|
||||
# 2. 开头和结尾必须是字母或数字
|
||||
# 3. 只能包含字母、数字、下划线或短横线
|
||||
# 4. 不能包含连续的点(..)
|
||||
# 5. 不能是有效的IPv4地址
|
||||
|
||||
# 使用SHA256哈希压缩ID长度,确保不超过63字符
|
||||
# 格式: u_{user_hash}_p_{project_hash} (约30字符)
|
||||
user_hash = hashlib.sha256(user_id.encode()).hexdigest()[:8]
|
||||
project_hash = hashlib.sha256(project_id.encode()).hexdigest()[:8]
|
||||
collection_name = f"u_{user_hash}_p_{project_hash}"
|
||||
|
||||
try:
|
||||
return self.client.get_or_create_collection(
|
||||
name=collection_name,
|
||||
metadata={
|
||||
"user_id": user_id,
|
||||
"project_id": project_id,
|
||||
"created_at": datetime.now().isoformat()
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 获取collection失败: {str(e)}")
|
||||
raise
|
||||
|
||||
async def add_memory(
|
||||
self,
|
||||
user_id: str,
|
||||
project_id: str,
|
||||
memory_id: str,
|
||||
content: str,
|
||||
memory_type: str,
|
||||
metadata: Dict[str, Any]
|
||||
) -> bool:
|
||||
"""
|
||||
添加记忆到向量数据库
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
project_id: 项目ID
|
||||
memory_id: 记忆唯一ID
|
||||
content: 记忆内容(将被转换为向量)
|
||||
memory_type: 记忆类型
|
||||
metadata: 附加元数据
|
||||
|
||||
Returns:
|
||||
是否添加成功
|
||||
"""
|
||||
try:
|
||||
collection = self.get_collection(user_id, project_id)
|
||||
|
||||
# 生成文本的向量表示
|
||||
embedding = self.embedding_model.encode(content).tolist()
|
||||
|
||||
# 准备元数据(ChromaDB要求所有值为基础类型)
|
||||
chroma_metadata = {
|
||||
"memory_type": memory_type,
|
||||
"chapter_id": str(metadata.get("chapter_id", "")),
|
||||
"chapter_number": int(metadata.get("chapter_number", 0)),
|
||||
"importance": float(metadata.get("importance_score", 0.5)),
|
||||
"tags": json.dumps(metadata.get("tags", []), ensure_ascii=False),
|
||||
"title": str(metadata.get("title", ""))[:200], # 限制长度
|
||||
"is_foreshadow": int(metadata.get("is_foreshadow", 0)),
|
||||
"created_at": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# 添加相关角色信息
|
||||
if metadata.get("related_characters"):
|
||||
chroma_metadata["related_characters"] = json.dumps(
|
||||
metadata["related_characters"],
|
||||
ensure_ascii=False
|
||||
)
|
||||
|
||||
# 存储到向量库
|
||||
collection.add(
|
||||
ids=[memory_id],
|
||||
embeddings=[embedding],
|
||||
documents=[content],
|
||||
metadatas=[chroma_metadata]
|
||||
)
|
||||
|
||||
logger.info(f"✅ 记忆已添加: {memory_id[:8]}... (类型:{memory_type}, 重要性:{chroma_metadata['importance']})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 添加记忆失败: {str(e)}")
|
||||
return False
|
||||
|
||||
async def batch_add_memories(
|
||||
self,
|
||||
user_id: str,
|
||||
project_id: str,
|
||||
memories: List[Dict[str, Any]]
|
||||
) -> int:
|
||||
"""
|
||||
批量添加记忆(性能更好)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
project_id: 项目ID
|
||||
memories: 记忆列表,每个包含id、content、type、metadata
|
||||
|
||||
Returns:
|
||||
成功添加的数量
|
||||
"""
|
||||
if not memories:
|
||||
return 0
|
||||
|
||||
try:
|
||||
collection = self.get_collection(user_id, project_id)
|
||||
|
||||
ids = []
|
||||
documents = []
|
||||
metadatas = []
|
||||
embeddings = []
|
||||
|
||||
# 批量准备数据
|
||||
for mem in memories:
|
||||
ids.append(mem['id'])
|
||||
documents.append(mem['content'])
|
||||
|
||||
# 生成embedding
|
||||
embedding = self.embedding_model.encode(mem['content']).tolist()
|
||||
embeddings.append(embedding)
|
||||
|
||||
# 准备元数据
|
||||
metadata = mem.get('metadata', {})
|
||||
chroma_metadata = {
|
||||
"memory_type": mem['type'],
|
||||
"chapter_id": str(metadata.get("chapter_id", "")),
|
||||
"chapter_number": int(metadata.get("chapter_number", 0)),
|
||||
"importance": float(metadata.get("importance_score", 0.5)),
|
||||
"tags": json.dumps(metadata.get("tags", []), ensure_ascii=False),
|
||||
"title": str(metadata.get("title", ""))[:200],
|
||||
"is_foreshadow": int(metadata.get("is_foreshadow", 0)),
|
||||
"created_at": datetime.now().isoformat()
|
||||
}
|
||||
metadatas.append(chroma_metadata)
|
||||
|
||||
# 批量添加
|
||||
collection.add(
|
||||
ids=ids,
|
||||
embeddings=embeddings,
|
||||
documents=documents,
|
||||
metadatas=metadatas
|
||||
)
|
||||
|
||||
logger.info(f"✅ 批量添加记忆成功: {len(memories)}条")
|
||||
return len(memories)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 批量添加记忆失败: {str(e)}")
|
||||
return 0
|
||||
|
||||
async def search_memories(
|
||||
self,
|
||||
user_id: str,
|
||||
project_id: str,
|
||||
query: str,
|
||||
memory_types: Optional[List[str]] = None,
|
||||
limit: int = 10,
|
||||
min_importance: float = 0.0,
|
||||
chapter_range: Optional[tuple] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
语义搜索相关记忆
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
project_id: 项目ID
|
||||
query: 查询文本(会被转换为向量进行相似度搜索)
|
||||
memory_types: 过滤特定类型的记忆
|
||||
limit: 返回结果数量
|
||||
min_importance: 最低重要性阈值
|
||||
chapter_range: 章节范围 (start, end)
|
||||
|
||||
Returns:
|
||||
相关记忆列表,按相似度排序
|
||||
"""
|
||||
try:
|
||||
collection = self.get_collection(user_id, project_id)
|
||||
|
||||
# 生成查询向量
|
||||
query_embedding = self.embedding_model.encode(query).tolist()
|
||||
|
||||
# 构建过滤条件 - ChromaDB要求使用$and组合多个条件
|
||||
where_filter = None
|
||||
conditions = []
|
||||
|
||||
if memory_types:
|
||||
conditions.append({"memory_type": {"$in": memory_types}})
|
||||
if min_importance > 0:
|
||||
conditions.append({"importance": {"$gte": min_importance}})
|
||||
if chapter_range:
|
||||
conditions.append({"chapter_number": {"$gte": chapter_range[0]}})
|
||||
conditions.append({"chapter_number": {"$lte": chapter_range[1]}})
|
||||
|
||||
# 根据条件数量选择合适的格式
|
||||
if len(conditions) == 0:
|
||||
where_filter = None
|
||||
elif len(conditions) == 1:
|
||||
where_filter = conditions[0]
|
||||
else:
|
||||
where_filter = {"$and": conditions}
|
||||
|
||||
# 执行向量相似度搜索
|
||||
results = collection.query(
|
||||
query_embeddings=[query_embedding],
|
||||
n_results=limit,
|
||||
where=where_filter
|
||||
)
|
||||
|
||||
# 格式化结果
|
||||
memories = []
|
||||
if results['ids'] and results['ids'][0]:
|
||||
for i in range(len(results['ids'][0])):
|
||||
memories.append({
|
||||
"id": results['ids'][0][i],
|
||||
"content": results['documents'][0][i],
|
||||
"metadata": results['metadatas'][0][i],
|
||||
"similarity": 1 - results['distances'][0][i] if 'distances' in results else 1.0,
|
||||
"distance": results['distances'][0][i] if 'distances' in results else 0.0
|
||||
})
|
||||
|
||||
logger.info(f"🔍 语义搜索完成: 查询='{query[:30]}...', 找到{len(memories)}条记忆")
|
||||
return memories
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 搜索记忆失败: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_recent_memories(
|
||||
self,
|
||||
user_id: str,
|
||||
project_id: str,
|
||||
current_chapter: int,
|
||||
recent_count: int = 3,
|
||||
min_importance: float = 0.5
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取最近几章的重要记忆(用于保持连贯性)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
project_id: 项目ID
|
||||
current_chapter: 当前章节号
|
||||
recent_count: 获取最近几章
|
||||
min_importance: 最低重要性阈值
|
||||
|
||||
Returns:
|
||||
最近章节的记忆列表,按重要性排序
|
||||
"""
|
||||
try:
|
||||
collection = self.get_collection(user_id, project_id)
|
||||
|
||||
# 计算章节范围
|
||||
start_chapter = max(1, current_chapter - recent_count)
|
||||
|
||||
# 获取最近章节的记忆
|
||||
results = collection.get(
|
||||
where={
|
||||
"$and": [
|
||||
{"chapter_number": {"$gte": start_chapter}},
|
||||
{"chapter_number": {"$lt": current_chapter}},
|
||||
{"importance": {"$gte": min_importance}}
|
||||
]
|
||||
},
|
||||
limit=100 # 先获取足够多的记忆
|
||||
)
|
||||
|
||||
memories = []
|
||||
if results['ids']:
|
||||
for i in range(len(results['ids'])):
|
||||
memories.append({
|
||||
"id": results['ids'][i],
|
||||
"content": results['documents'][i],
|
||||
"metadata": results['metadatas'][i]
|
||||
})
|
||||
|
||||
# 按重要性和章节号排序
|
||||
memories.sort(
|
||||
key=lambda x: (float(x['metadata'].get('importance', 0)),
|
||||
int(x['metadata'].get('chapter_number', 0))),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
# 返回最重要的前N条
|
||||
top_memories = memories[:20]
|
||||
logger.info(f"📚 获取最近记忆: 章节{start_chapter}-{current_chapter-1}, 找到{len(top_memories)}条")
|
||||
return top_memories
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 获取最近记忆失败: {str(e)}")
|
||||
return []
|
||||
|
||||
async def find_unresolved_foreshadows(
|
||||
self,
|
||||
user_id: str,
|
||||
project_id: str,
|
||||
current_chapter: int
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
查找未完结的伏笔
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
project_id: 项目ID
|
||||
current_chapter: 当前章节号
|
||||
|
||||
Returns:
|
||||
未完结伏笔列表
|
||||
"""
|
||||
try:
|
||||
collection = self.get_collection(user_id, project_id)
|
||||
|
||||
# 查找伏笔状态为1(已埋下但未回收)的记忆
|
||||
results = collection.get(
|
||||
where={
|
||||
"$and": [
|
||||
{"is_foreshadow": 1},
|
||||
{"chapter_number": {"$lt": current_chapter}}
|
||||
]
|
||||
},
|
||||
limit=50
|
||||
)
|
||||
|
||||
foreshadows = []
|
||||
if results['ids']:
|
||||
for i in range(len(results['ids'])):
|
||||
foreshadows.append({
|
||||
"id": results['ids'][i],
|
||||
"content": results['documents'][i],
|
||||
"metadata": results['metadatas'][i]
|
||||
})
|
||||
|
||||
# 按重要性排序
|
||||
foreshadows.sort(
|
||||
key=lambda x: float(x['metadata'].get('importance', 0)),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
logger.info(f"🎣 找到未完结伏笔: {len(foreshadows)}个")
|
||||
return foreshadows
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 查找伏笔失败: {str(e)}")
|
||||
return []
|
||||
|
||||
async def build_context_for_generation(
|
||||
self,
|
||||
user_id: str,
|
||||
project_id: str,
|
||||
current_chapter: int,
|
||||
chapter_outline: str,
|
||||
character_names: List[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
为章节生成构建智能上下文
|
||||
|
||||
这是核心功能: 结合多种检索策略,为AI生成提供最相关的记忆
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
project_id: 项目ID
|
||||
current_chapter: 当前章节号
|
||||
chapter_outline: 本章大纲
|
||||
character_names: 涉及的角色名列表
|
||||
|
||||
Returns:
|
||||
包含各种上下文信息的字典
|
||||
"""
|
||||
logger.info(f"🧠 开始构建章节{current_chapter}的智能上下文...")
|
||||
|
||||
# 1. 获取最近章节上下文(时间连续性)
|
||||
recent = await self.get_recent_memories(
|
||||
user_id, project_id, current_chapter,
|
||||
recent_count=3, min_importance=0.5
|
||||
)
|
||||
|
||||
# 2. 语义搜索相关记忆
|
||||
relevant = await self.search_memories(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
query=chapter_outline,
|
||||
limit=10,
|
||||
min_importance=0.4
|
||||
)
|
||||
|
||||
# 3. 查找未完结伏笔
|
||||
foreshadows = await self.find_unresolved_foreshadows(
|
||||
user_id, project_id, current_chapter
|
||||
)
|
||||
|
||||
# 4. 如果有指定角色,获取角色相关记忆
|
||||
character_memories = []
|
||||
if character_names:
|
||||
character_query = " ".join(character_names) + " 角色 状态 关系"
|
||||
character_memories = await self.search_memories(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
query=character_query,
|
||||
memory_types=["character_event", "plot_point"],
|
||||
limit=8
|
||||
)
|
||||
|
||||
# 5. 获取重要情节点
|
||||
# 注意:ChromaDB的where条件需要特殊处理,不能同时使用多个顶层条件
|
||||
try:
|
||||
plot_points = await self.search_memories(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
query="重要 转折 高潮 关键",
|
||||
memory_types=["plot_point", "hook"],
|
||||
limit=5,
|
||||
min_importance=0.7
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 搜索记忆失败: {str(e)}")
|
||||
# 降级处理:分别查询
|
||||
plot_points = []
|
||||
try:
|
||||
plot_points = await self.search_memories(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
query="重要 转折 高潮 关键",
|
||||
memory_types=["plot_point", "hook"],
|
||||
limit=5
|
||||
)
|
||||
except Exception as e2:
|
||||
logger.warning(f"⚠️ 降级查询也失败: {str(e2)}")
|
||||
plot_points = []
|
||||
|
||||
context = {
|
||||
"recent_context": self._format_memories(recent, "最近章节记忆"),
|
||||
"relevant_memories": self._format_memories(relevant, "语义相关记忆"),
|
||||
"character_states": self._format_memories(character_memories, "角色相关记忆"),
|
||||
"foreshadows": self._format_memories(foreshadows[:5], "未完结伏笔"),
|
||||
"plot_points": self._format_memories(plot_points, "重要情节点"),
|
||||
"stats": {
|
||||
"recent_count": len(recent),
|
||||
"relevant_count": len(relevant),
|
||||
"character_count": len(character_memories),
|
||||
"foreshadow_count": len(foreshadows),
|
||||
"plot_point_count": len(plot_points)
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(f"✅ 上下文构建完成: 最近{len(recent)}条, 相关{len(relevant)}条, 伏笔{len(foreshadows)}个")
|
||||
return context
|
||||
def _format_memories(self, memories: List[Dict], section_title: str = "记忆") -> str:
|
||||
"""
|
||||
格式化记忆列表为文本
|
||||
|
||||
Args:
|
||||
memories: 记忆列表
|
||||
section_title: 章节标题
|
||||
|
||||
Returns:
|
||||
格式化后的文本
|
||||
"""
|
||||
if not memories:
|
||||
return f"【{section_title}】\n暂无相关记忆\n"
|
||||
|
||||
lines = [f"【{section_title}】"]
|
||||
for i, mem in enumerate(memories, 1):
|
||||
meta = mem.get('metadata', {})
|
||||
chapter_num = meta.get('chapter_number', '?')
|
||||
mem_type = meta.get('memory_type', '未知')
|
||||
importance = float(meta.get('importance', 0.5))
|
||||
title = meta.get('title', '')
|
||||
content = mem['content']
|
||||
|
||||
# 格式: [序号] 第X章-类型(重要性) 标题: 内容
|
||||
line = f"{i}. [第{chapter_num}章-{mem_type}★{importance:.1f}]"
|
||||
if title:
|
||||
line += f" {title}: {content[:100]}"
|
||||
else:
|
||||
line += f" {content[:150]}"
|
||||
lines.append(line)
|
||||
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
async def delete_chapter_memories(
|
||||
self,
|
||||
user_id: str,
|
||||
project_id: str,
|
||||
chapter_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
删除指定章节的所有记忆
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
project_id: 项目ID
|
||||
chapter_id: 章节ID
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
collection = self.get_collection(user_id, project_id)
|
||||
|
||||
# 查找该章节的所有记忆
|
||||
results = collection.get(
|
||||
where={"chapter_id": chapter_id}
|
||||
)
|
||||
|
||||
if results['ids']:
|
||||
# 删除这些记忆
|
||||
collection.delete(ids=results['ids'])
|
||||
logger.info(f"🗑️ 已删除章节{chapter_id[:8]}的{len(results['ids'])}条记忆")
|
||||
return True
|
||||
else:
|
||||
logger.info(f"ℹ️ 章节{chapter_id[:8]}没有记忆需要删除")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 删除章节记忆失败: {str(e)}")
|
||||
return False
|
||||
|
||||
async def update_memory(
|
||||
self,
|
||||
user_id: str,
|
||||
project_id: str,
|
||||
memory_id: str,
|
||||
content: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> bool:
|
||||
"""
|
||||
更新记忆内容或元数据
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
project_id: 项目ID
|
||||
memory_id: 记忆ID
|
||||
content: 新内容(可选)
|
||||
metadata: 新元数据(可选)
|
||||
|
||||
Returns:
|
||||
是否更新成功
|
||||
"""
|
||||
try:
|
||||
collection = self.get_collection(user_id, project_id)
|
||||
|
||||
update_data = {}
|
||||
|
||||
if content:
|
||||
# 重新生成embedding
|
||||
embedding = self.embedding_model.encode(content).tolist()
|
||||
update_data['embeddings'] = [embedding]
|
||||
update_data['documents'] = [content]
|
||||
|
||||
if metadata:
|
||||
# 准备新的元数据
|
||||
chroma_metadata = {}
|
||||
for key, value in metadata.items():
|
||||
if isinstance(value, (list, dict)):
|
||||
chroma_metadata[key] = json.dumps(value, ensure_ascii=False)
|
||||
else:
|
||||
chroma_metadata[key] = value
|
||||
update_data['metadatas'] = [chroma_metadata]
|
||||
|
||||
if update_data:
|
||||
collection.update(
|
||||
ids=[memory_id],
|
||||
**update_data
|
||||
)
|
||||
logger.info(f"✅ 记忆已更新: {memory_id[:8]}...")
|
||||
return True
|
||||
else:
|
||||
logger.warning("⚠️ 没有提供更新内容")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 更新记忆失败: {str(e)}")
|
||||
return False
|
||||
|
||||
async def get_memory_stats(
|
||||
self,
|
||||
user_id: str,
|
||||
project_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取记忆统计信息
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
project_id: 项目ID
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
try:
|
||||
collection = self.get_collection(user_id, project_id)
|
||||
|
||||
# 获取所有记忆
|
||||
all_memories = collection.get()
|
||||
|
||||
if not all_memories['ids']:
|
||||
return {
|
||||
"total_count": 0,
|
||||
"by_type": {},
|
||||
"by_chapter": {},
|
||||
"foreshadow_count": 0
|
||||
}
|
||||
|
||||
# 统计各类型数量
|
||||
type_counts = {}
|
||||
chapter_counts = {}
|
||||
foreshadow_count = 0
|
||||
|
||||
for i, meta in enumerate(all_memories['metadatas']):
|
||||
mem_type = meta.get('memory_type', 'unknown')
|
||||
chapter_num = meta.get('chapter_number', 0)
|
||||
is_foreshadow = meta.get('is_foreshadow', 0)
|
||||
|
||||
type_counts[mem_type] = type_counts.get(mem_type, 0) + 1
|
||||
chapter_counts[str(chapter_num)] = chapter_counts.get(str(chapter_num), 0) + 1
|
||||
|
||||
if is_foreshadow == 1:
|
||||
foreshadow_count += 1
|
||||
|
||||
stats = {
|
||||
"total_count": len(all_memories['ids']),
|
||||
"by_type": type_counts,
|
||||
"by_chapter": chapter_counts,
|
||||
"foreshadow_count": foreshadow_count,
|
||||
"foreshadow_resolved": sum(1 for m in all_memories['metadatas'] if m.get('is_foreshadow') == 2)
|
||||
}
|
||||
|
||||
logger.info(f"📊 记忆统计: 总计{stats['total_count']}条, 伏笔{foreshadow_count}个")
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 获取统计信息失败: {str(e)}")
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
memory_service = MemoryService()
|
||||
|
||||
@@ -0,0 +1,559 @@
|
||||
"""剧情分析服务 - 自动分析章节的钩子、伏笔、冲突等元素"""
|
||||
from typing import Dict, Any, List, Optional
|
||||
from app.services.ai_service import AIService
|
||||
from app.logger import get_logger
|
||||
import json
|
||||
import re
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PlotAnalyzer:
|
||||
"""剧情分析器 - 使用AI分析章节内容"""
|
||||
|
||||
# AI分析提示词模板
|
||||
ANALYSIS_PROMPT = """你是一位专业的小说编辑和剧情分析师。请深度分析以下章节内容:
|
||||
|
||||
**章节信息:**
|
||||
- 章节: 第{chapter_number}章
|
||||
- 标题: {title}
|
||||
- 字数: {word_count}字
|
||||
|
||||
**章节内容:**
|
||||
{content}
|
||||
|
||||
---
|
||||
|
||||
**分析任务:**
|
||||
请从专业编辑的角度,全面分析这一章节:
|
||||
|
||||
### 1. 剧情钩子 (Hooks) - 吸引读者的元素
|
||||
识别能够吸引读者继续阅读的关键元素:
|
||||
- **悬念钩子**: 未解之谜、疑问、谜团
|
||||
- **情感钩子**: 引发共鸣的情感点、触动心弦的时刻
|
||||
- **冲突钩子**: 矛盾对抗、紧张局势
|
||||
- **认知钩子**: 颠覆认知的信息、惊人真相
|
||||
|
||||
每个钩子需要:
|
||||
- 类型分类
|
||||
- 具体内容描述
|
||||
- 强度评分(1-10)
|
||||
- 出现位置(开头/中段/结尾)
|
||||
- **关键词**: 【必填】从章节原文中逐字复制一段关键文本(8-25字),必须是原文中真实存在的连续文字,用于在文本中精确定位。不要概括或改写,必须原样复制!
|
||||
|
||||
### 2. 伏笔分析 (Foreshadowing)
|
||||
- **埋下的新伏笔**: 描述内容、预期作用、隐藏程度(1-10)
|
||||
- **回收的旧伏笔**: 呼应哪一章、回收效果评分
|
||||
- **伏笔质量**: 巧妙性和合理性评估
|
||||
- **关键词**: 【必填】从章节原文中逐字复制一段关键文本(8-25字),必须是原文中真实存在的连续文字,用于在文本中精确定位。不要概括或改写,必须原样复制!
|
||||
|
||||
### 3. 冲突分析 (Conflict)
|
||||
- 冲突类型: 人与人/人与己/人与环境/人与社会
|
||||
- 冲突各方及其立场
|
||||
- 冲突强度评分(1-10)
|
||||
- 冲突解决进度(0-100%)
|
||||
|
||||
### 4. 情感曲线 (Emotional Arc)
|
||||
- 主导情绪: 紧张/温馨/悲伤/激昂/平静等
|
||||
- 情感强度(1-10)
|
||||
- 情绪变化轨迹描述
|
||||
|
||||
### 5. 角色状态追踪 (Character Development)
|
||||
对每个出场角色分析:
|
||||
- 心理状态变化(前→后)
|
||||
- 关系变化
|
||||
- 关键行动和决策
|
||||
- 成长或退步
|
||||
|
||||
### 6. 关键情节点 (Plot Points)
|
||||
列出3-5个核心情节点:
|
||||
- 情节内容
|
||||
- 类型(revelation/conflict/resolution/transition)
|
||||
- 重要性(0.0-1.0)
|
||||
- 对故事的影响
|
||||
- **关键词**: 【必填】从章节原文中逐字复制一段关键文本(8-25字),必须是原文中真实存在的连续文字,用于在文本中精确定位。不要概括或改写,必须原样复制!
|
||||
|
||||
### 7. 场景与节奏
|
||||
- 主要场景
|
||||
- 叙事节奏(快/中/慢)
|
||||
- 对话与描写的比例
|
||||
|
||||
### 8. 质量评分
|
||||
- 节奏把控: 1-10分
|
||||
- 吸引力: 1-10分
|
||||
- 连贯性: 1-10分
|
||||
- 整体质量: 1-10分
|
||||
|
||||
### 9. 改进建议
|
||||
提供3-5条具体的改进建议
|
||||
|
||||
---
|
||||
|
||||
**输出格式(纯JSON,不要markdown标记):**
|
||||
|
||||
{{
|
||||
"hooks": [
|
||||
{{
|
||||
"type": "悬念",
|
||||
"content": "具体描述",
|
||||
"strength": 8,
|
||||
"position": "中段",
|
||||
"keyword": "必须从原文逐字复制的文本片段"
|
||||
}}
|
||||
],
|
||||
"foreshadows": [
|
||||
{{
|
||||
"content": "伏笔内容",
|
||||
"type": "planted",
|
||||
"strength": 7,
|
||||
"subtlety": 8,
|
||||
"reference_chapter": null,
|
||||
"keyword": "必须从原文逐字复制的文本片段"
|
||||
}}
|
||||
],
|
||||
"conflict": {{
|
||||
"types": ["人与人", "人与己"],
|
||||
"parties": ["主角-复仇", "反派-维护现状"],
|
||||
"level": 8,
|
||||
"description": "冲突描述",
|
||||
"resolution_progress": 0.3
|
||||
}},
|
||||
"emotional_arc": {{
|
||||
"primary_emotion": "紧张",
|
||||
"intensity": 8,
|
||||
"curve": "平静→紧张→高潮→释放",
|
||||
"secondary_emotions": ["期待", "焦虑"]
|
||||
}},
|
||||
"character_states": [
|
||||
{{
|
||||
"character_name": "张三",
|
||||
"state_before": "犹豫",
|
||||
"state_after": "坚定",
|
||||
"psychological_change": "心理变化描述",
|
||||
"key_event": "触发事件",
|
||||
"relationship_changes": {{"李四": "关系改善"}}
|
||||
}}
|
||||
],
|
||||
"plot_points": [
|
||||
{{
|
||||
"content": "情节点描述",
|
||||
"type": "revelation",
|
||||
"importance": 0.9,
|
||||
"impact": "推动故事发展",
|
||||
"keyword": "必须从原文逐字复制的文本片段"
|
||||
}}
|
||||
],
|
||||
"scenes": [
|
||||
{{
|
||||
"location": "地点",
|
||||
"atmosphere": "氛围",
|
||||
"duration": "时长估计"
|
||||
}}
|
||||
],
|
||||
"pacing": "varied",
|
||||
"dialogue_ratio": 0.4,
|
||||
"description_ratio": 0.3,
|
||||
"scores": {{
|
||||
"pacing": 8,
|
||||
"engagement": 9,
|
||||
"coherence": 8,
|
||||
"overall": 8.5
|
||||
}},
|
||||
"plot_stage": "发展",
|
||||
"suggestions": [
|
||||
"具体建议1",
|
||||
"具体建议2"
|
||||
]
|
||||
}}
|
||||
|
||||
**重要提示:**
|
||||
1. 每个钩子、伏笔、情节点的keyword字段是必填的,不能为空
|
||||
2. keyword必须是从章节原文中逐字复制的文本,长度8-25字
|
||||
3. keyword用于在前端标注文本位置,所以必须能在原文中精确找到
|
||||
4. 不要使用概括性语句或改写后的文字作为keyword
|
||||
|
||||
只返回JSON,不要其他说明。"""
|
||||
|
||||
def __init__(self, ai_service: AIService):
|
||||
"""
|
||||
初始化剧情分析器
|
||||
|
||||
Args:
|
||||
ai_service: AI服务实例
|
||||
"""
|
||||
self.ai_service = ai_service
|
||||
logger.info("✅ PlotAnalyzer初始化成功")
|
||||
|
||||
async def analyze_chapter(
|
||||
self,
|
||||
chapter_number: int,
|
||||
title: str,
|
||||
content: str,
|
||||
word_count: int
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
分析单章内容
|
||||
|
||||
Args:
|
||||
chapter_number: 章节号
|
||||
title: 章节标题
|
||||
content: 章节内容
|
||||
word_count: 字数
|
||||
|
||||
Returns:
|
||||
分析结果字典,失败返回None
|
||||
"""
|
||||
try:
|
||||
logger.info(f"🔍 开始分析第{chapter_number}章: {title}")
|
||||
|
||||
# 如果内容过长,截取前8000字(避免超token)
|
||||
analysis_content = content[:8000] if len(content) > 8000 else content
|
||||
|
||||
# 构建提示词
|
||||
prompt = self.ANALYSIS_PROMPT.format(
|
||||
chapter_number=chapter_number,
|
||||
title=title,
|
||||
word_count=word_count,
|
||||
content=analysis_content
|
||||
)
|
||||
|
||||
# 调用AI进行分析
|
||||
# 注意:不指定max_tokens,使用用户在设置中配置的值
|
||||
logger.info(f" 调用AI分析(内容长度: {len(analysis_content)}字)...")
|
||||
response = await self.ai_service.generate_text(
|
||||
prompt=prompt,
|
||||
temperature=0.3 # 降低温度以获得更稳定的JSON输出
|
||||
)
|
||||
|
||||
# 解析JSON结果
|
||||
analysis_result = self._parse_analysis_response(response)
|
||||
|
||||
if analysis_result:
|
||||
logger.info(f"✅ 第{chapter_number}章分析完成")
|
||||
logger.info(f" - 钩子: {len(analysis_result.get('hooks', []))}个")
|
||||
logger.info(f" - 伏笔: {len(analysis_result.get('foreshadows', []))}个")
|
||||
logger.info(f" - 情节点: {len(analysis_result.get('plot_points', []))}个")
|
||||
logger.info(f" - 整体评分: {analysis_result.get('scores', {}).get('overall', 'N/A')}")
|
||||
return analysis_result
|
||||
else:
|
||||
logger.error(f"❌ 第{chapter_number}章分析失败: JSON解析错误")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 章节分析异常: {str(e)}")
|
||||
return None
|
||||
|
||||
def _parse_analysis_response(self, response: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
解析AI返回的分析结果
|
||||
|
||||
Args:
|
||||
response: AI返回的文本
|
||||
|
||||
Returns:
|
||||
解析后的字典,失败返回None
|
||||
"""
|
||||
try:
|
||||
# 清理响应文本
|
||||
cleaned = response.strip()
|
||||
|
||||
# 移除可能的markdown标记
|
||||
cleaned = re.sub(r'^```json\s*', '', cleaned)
|
||||
cleaned = re.sub(r'^```\s*', '', cleaned)
|
||||
cleaned = re.sub(r'\s*```$', '', cleaned)
|
||||
|
||||
# 尝试解析JSON
|
||||
result = json.loads(cleaned)
|
||||
|
||||
# 验证必要字段
|
||||
required_fields = ['hooks', 'plot_points', 'scores']
|
||||
for field in required_fields:
|
||||
if field not in result:
|
||||
logger.warning(f"⚠️ 分析结果缺少字段: {field}")
|
||||
result[field] = [] if field != 'scores' else {}
|
||||
|
||||
return result
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"❌ JSON解析失败: {str(e)}")
|
||||
logger.error(f" 原始响应(前500字): {response[:500]}")
|
||||
|
||||
# 尝试提取JSON部分
|
||||
json_match = re.search(r'\{[\s\S]*\}', response)
|
||||
if json_match:
|
||||
try:
|
||||
result = json.loads(json_match.group())
|
||||
logger.info("✅ 通过正则提取成功解析JSON")
|
||||
return result
|
||||
except:
|
||||
pass
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 解析异常: {str(e)}")
|
||||
return None
|
||||
|
||||
def extract_memories_from_analysis(
|
||||
self,
|
||||
analysis: Dict[str, Any],
|
||||
chapter_id: str,
|
||||
chapter_number: int,
|
||||
chapter_content: str = ""
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
从分析结果中提取记忆片段
|
||||
|
||||
Args:
|
||||
analysis: 分析结果
|
||||
chapter_id: 章节ID
|
||||
chapter_number: 章节号
|
||||
chapter_content: 章节完整内容(用于计算位置)
|
||||
|
||||
Returns:
|
||||
记忆片段列表
|
||||
"""
|
||||
memories = []
|
||||
|
||||
try:
|
||||
# 1. 提取钩子作为记忆
|
||||
for i, hook in enumerate(analysis.get('hooks', [])):
|
||||
if hook.get('strength', 0) >= 6: # 只保存强度>=6的钩子
|
||||
keyword = hook.get('keyword', '')
|
||||
position, length = self._find_text_position(chapter_content, keyword)
|
||||
|
||||
logger.info(f" 钩子位置: keyword='{keyword[:30]}...', pos={position}, len={length}")
|
||||
|
||||
memories.append({
|
||||
'type': 'hook',
|
||||
'content': f"[{hook.get('type', '未知')}钩子] {hook.get('content', '')}",
|
||||
'title': f"{hook.get('type', '钩子')} - {hook.get('position', '')}",
|
||||
'metadata': {
|
||||
'chapter_id': chapter_id,
|
||||
'chapter_number': chapter_number,
|
||||
'importance_score': min(hook.get('strength', 5) / 10, 1.0),
|
||||
'tags': [hook.get('type', '钩子'), hook.get('position', '')],
|
||||
'is_foreshadow': 0,
|
||||
'keyword': keyword,
|
||||
'text_position': position,
|
||||
'text_length': length,
|
||||
'strength': hook.get('strength', 5),
|
||||
'position_desc': hook.get('position', '')
|
||||
}
|
||||
})
|
||||
|
||||
# 2. 提取伏笔作为记忆
|
||||
for i, foreshadow in enumerate(analysis.get('foreshadows', [])):
|
||||
is_planted = foreshadow.get('type') == 'planted'
|
||||
keyword = foreshadow.get('keyword', '')
|
||||
position, length = self._find_text_position(chapter_content, keyword)
|
||||
|
||||
logger.info(f" 伏笔位置: keyword='{keyword[:30]}...', pos={position}, len={length}")
|
||||
|
||||
memories.append({
|
||||
'type': 'foreshadow',
|
||||
'content': foreshadow.get('content', ''),
|
||||
'title': f"{'埋下伏笔' if is_planted else '回收伏笔'}",
|
||||
'metadata': {
|
||||
'chapter_id': chapter_id,
|
||||
'chapter_number': chapter_number,
|
||||
'importance_score': min(foreshadow.get('strength', 5) / 10, 1.0),
|
||||
'tags': ['伏笔', foreshadow.get('type', 'planted')],
|
||||
'is_foreshadow': 1 if is_planted else 2,
|
||||
'reference_chapter': foreshadow.get('reference_chapter'),
|
||||
'keyword': keyword,
|
||||
'text_position': position,
|
||||
'text_length': length,
|
||||
'foreshadow_type': foreshadow.get('type', 'planted'),
|
||||
'strength': foreshadow.get('strength', 5)
|
||||
}
|
||||
})
|
||||
|
||||
# 3. 提取关键情节点
|
||||
for i, plot_point in enumerate(analysis.get('plot_points', [])):
|
||||
if plot_point.get('importance', 0) >= 0.6: # 只保存重要性>=0.6的情节点
|
||||
keyword = plot_point.get('keyword', '')
|
||||
position, length = self._find_text_position(chapter_content, keyword)
|
||||
|
||||
logger.info(f" 情节点位置: keyword='{keyword[:30]}...', pos={position}, len={length}")
|
||||
|
||||
memories.append({
|
||||
'type': 'plot_point',
|
||||
'content': f"{plot_point.get('content', '')}。影响: {plot_point.get('impact', '')}",
|
||||
'title': f"情节点 - {plot_point.get('type', '未知')}",
|
||||
'metadata': {
|
||||
'chapter_id': chapter_id,
|
||||
'chapter_number': chapter_number,
|
||||
'importance_score': plot_point.get('importance', 0.5),
|
||||
'tags': ['情节点', plot_point.get('type', '未知')],
|
||||
'is_foreshadow': 0,
|
||||
'keyword': keyword,
|
||||
'text_position': position,
|
||||
'text_length': length
|
||||
}
|
||||
})
|
||||
|
||||
# 4. 提取角色状态变化
|
||||
for i, char_state in enumerate(analysis.get('character_states', [])):
|
||||
char_name = char_state.get('character_name', '未知角色')
|
||||
memories.append({
|
||||
'type': 'character_event',
|
||||
'content': f"{char_name}的状态变化: {char_state.get('state_before', '')} → {char_state.get('state_after', '')}。{char_state.get('psychological_change', '')}",
|
||||
'title': f"{char_name}的变化",
|
||||
'metadata': {
|
||||
'chapter_id': chapter_id,
|
||||
'chapter_number': chapter_number,
|
||||
'importance_score': 0.7,
|
||||
'tags': ['角色', char_name, '状态变化'],
|
||||
'related_characters': [char_name],
|
||||
'is_foreshadow': 0
|
||||
}
|
||||
})
|
||||
|
||||
# 5. 如果有重要冲突,也记录下来
|
||||
conflict = analysis.get('conflict', {})
|
||||
|
||||
if conflict and conflict.get('level', 0) >= 7:
|
||||
# 确保 parties 和 types 都是字符串列表
|
||||
parties = conflict.get('parties', [])
|
||||
if parties and isinstance(parties, list):
|
||||
parties = [str(p) for p in parties]
|
||||
|
||||
types = conflict.get('types', [])
|
||||
if types and isinstance(types, list):
|
||||
types = [str(t) for t in types]
|
||||
|
||||
memories.append({
|
||||
'type': 'plot_point',
|
||||
'content': f"重要冲突: {conflict.get('description', '')}。冲突各方: {', '.join(parties)}",
|
||||
'title': f"冲突 - 强度{conflict.get('level', 0)}",
|
||||
'metadata': {
|
||||
'chapter_id': chapter_id,
|
||||
'chapter_number': chapter_number,
|
||||
'importance_score': min(conflict.get('level', 5) / 10, 1.0),
|
||||
'tags': ['冲突'] + types,
|
||||
'is_foreshadow': 0
|
||||
}
|
||||
})
|
||||
|
||||
logger.info(f"📝 从分析中提取了{len(memories)}条记忆")
|
||||
return memories
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 提取记忆失败: {str(e)}")
|
||||
return []
|
||||
|
||||
def _find_text_position(self, full_text: str, keyword: str) -> tuple[int, int]:
|
||||
"""
|
||||
在全文中查找关键词位置
|
||||
|
||||
Args:
|
||||
full_text: 完整文本
|
||||
keyword: 关键词
|
||||
|
||||
Returns:
|
||||
(起始位置, 长度) 如果未找到返回(-1, 0)
|
||||
"""
|
||||
if not keyword or not full_text:
|
||||
return (-1, 0)
|
||||
|
||||
try:
|
||||
# 1. 精确匹配
|
||||
pos = full_text.find(keyword)
|
||||
if pos != -1:
|
||||
return (pos, len(keyword))
|
||||
|
||||
# 2. 去除标点符号后匹配
|
||||
import re
|
||||
clean_keyword = re.sub(r'[,。!?、;:""''()《》【】]', '', keyword)
|
||||
clean_text = re.sub(r'[,。!?、;:""''()《》【】]', '', full_text)
|
||||
pos = clean_text.find(clean_keyword)
|
||||
|
||||
if pos != -1:
|
||||
# 反向映射到原文位置(简化处理)
|
||||
return (pos, len(clean_keyword))
|
||||
|
||||
# 3. 模糊匹配:查找关键词的前半部分
|
||||
if len(keyword) > 10:
|
||||
partial = keyword[:min(15, len(keyword))]
|
||||
pos = full_text.find(partial)
|
||||
if pos != -1:
|
||||
return (pos, len(partial))
|
||||
|
||||
# 4. 未找到
|
||||
logger.debug(f"未找到关键词位置: {keyword[:30]}...")
|
||||
return (-1, 0)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查找位置失败: {str(e)}")
|
||||
return (-1, 0)
|
||||
|
||||
def generate_analysis_summary(self, analysis: Dict[str, Any]) -> str:
|
||||
"""
|
||||
生成分析摘要文本
|
||||
|
||||
Args:
|
||||
analysis: 分析结果
|
||||
|
||||
Returns:
|
||||
格式化的摘要文本
|
||||
"""
|
||||
try:
|
||||
lines = ["=== 章节分析报告 ===\n"]
|
||||
|
||||
# 整体评分
|
||||
scores = analysis.get('scores', {})
|
||||
lines.append(f"【整体评分】")
|
||||
lines.append(f" 整体质量: {scores.get('overall', 'N/A')}/10")
|
||||
lines.append(f" 节奏把控: {scores.get('pacing', 'N/A')}/10")
|
||||
lines.append(f" 吸引力: {scores.get('engagement', 'N/A')}/10")
|
||||
lines.append(f" 连贯性: {scores.get('coherence', 'N/A')}/10\n")
|
||||
|
||||
# 剧情阶段
|
||||
lines.append(f"【剧情阶段】{analysis.get('plot_stage', '未知')}\n")
|
||||
|
||||
# 钩子统计
|
||||
hooks = analysis.get('hooks', [])
|
||||
if hooks:
|
||||
lines.append(f"【钩子分析】共{len(hooks)}个")
|
||||
for hook in hooks[:3]: # 只显示前3个
|
||||
lines.append(f" • [{hook.get('type')}] {hook.get('content', '')[:50]}... (强度:{hook.get('strength', 0)})")
|
||||
lines.append("")
|
||||
|
||||
# 伏笔统计
|
||||
foreshadows = analysis.get('foreshadows', [])
|
||||
if foreshadows:
|
||||
planted = sum(1 for f in foreshadows if f.get('type') == 'planted')
|
||||
resolved = sum(1 for f in foreshadows if f.get('type') == 'resolved')
|
||||
lines.append(f"【伏笔分析】埋下{planted}个, 回收{resolved}个\n")
|
||||
|
||||
# 冲突分析
|
||||
conflict = analysis.get('conflict', {})
|
||||
if conflict:
|
||||
lines.append(f"【冲突分析】")
|
||||
lines.append(f" 类型: {', '.join(conflict.get('types', []))}")
|
||||
lines.append(f" 强度: {conflict.get('level', 0)}/10")
|
||||
lines.append(f" 进度: {int(conflict.get('resolution_progress', 0) * 100)}%\n")
|
||||
|
||||
# 改进建议
|
||||
suggestions = analysis.get('suggestions', [])
|
||||
if suggestions:
|
||||
lines.append(f"【改进建议】")
|
||||
for i, sug in enumerate(suggestions, 1):
|
||||
lines.append(f" {i}. {sug}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 生成摘要失败: {str(e)}")
|
||||
return "分析摘要生成失败"
|
||||
|
||||
|
||||
# 创建全局实例(需要时手动初始化)
|
||||
_plot_analyzer_instance = None
|
||||
|
||||
def get_plot_analyzer(ai_service: AIService) -> PlotAnalyzer:
|
||||
"""获取剧情分析器实例"""
|
||||
global _plot_analyzer_instance
|
||||
if _plot_analyzer_instance is None:
|
||||
_plot_analyzer_instance = PlotAnalyzer(ai_service)
|
||||
return _plot_analyzer_instance
|
||||
@@ -315,7 +315,7 @@ class PromptService:
|
||||
2. 数组中要包含{chapter_count}个章节对象
|
||||
3. 文本中不要使用中文引号(""),改用【】或《》"""
|
||||
|
||||
# 大纲续写提示词
|
||||
# 大纲续写提示词(记忆增强版)
|
||||
OUTLINE_CONTINUE_GENERATION = """你是一位经验丰富的小说作家和编剧。请基于以下信息续写小说大纲:
|
||||
|
||||
【项目信息】
|
||||
@@ -340,6 +340,11 @@ class PromptService:
|
||||
【最近剧情】
|
||||
{recent_plot}
|
||||
|
||||
【🧠 智能记忆系统 - 续写参考】
|
||||
以下是从故事记忆库中检索到的相关信息,请在续写大纲时参考:
|
||||
|
||||
{memory_context}
|
||||
|
||||
【续写指导】
|
||||
- 当前情节阶段:{plot_stage_instruction}
|
||||
- 起始章节编号:第{start_chapter}章
|
||||
@@ -348,10 +353,12 @@ class PromptService:
|
||||
|
||||
请生成第{start_chapter}章到第{end_chapter}章的大纲。
|
||||
要求:
|
||||
- 与前文自然衔接,保持故事连贯性
|
||||
- 遵循情节阶段的发展要求
|
||||
- 保持与已有章节相同的风格和详细程度
|
||||
- 推进角色成长和情节发展
|
||||
- **剧情连贯性**:与前文自然衔接,保持故事连贯性
|
||||
- **记忆参考**:适当参考记忆系统中的伏笔、钩子和情节点
|
||||
- **伏笔回收**:可以考虑回收未完结的伏笔,制造呼应
|
||||
- **角色发展**:遵循角色在前文中的成长轨迹
|
||||
- **情节阶段**:遵循情节阶段的发展要求
|
||||
- **风格一致**:保持与已有章节相同的风格和详细程度
|
||||
|
||||
**重要格式要求:**
|
||||
1. 只返回纯JSON数组格式,不要包含任何markdown标记、代码块标记或其他说明文字
|
||||
@@ -465,7 +472,7 @@ class PromptService:
|
||||
|
||||
请直接输出章节正文内容,不要包含章节标题和其他说明文字。"""
|
||||
|
||||
# 章节完整创作提示词(带前置章节上下文)
|
||||
# 章节完整创作提示词(带前置章节上下文和记忆增强)
|
||||
CHAPTER_GENERATION_WITH_CONTEXT = """你是一位专业的小说作家。请根据以下信息创作本章内容:
|
||||
|
||||
项目信息:
|
||||
@@ -489,6 +496,11 @@ class PromptService:
|
||||
【已完成的前置章节内容】
|
||||
{previous_content}
|
||||
|
||||
【🧠 智能记忆系统 - 重要参考】
|
||||
以下是从故事记忆库中检索到的相关信息,请在创作时适当参考和呼应:
|
||||
|
||||
{memory_context}
|
||||
|
||||
本章信息:
|
||||
- 章节序号:第{chapter_number}章
|
||||
- 章节标题:{chapter_title}
|
||||
@@ -518,8 +530,15 @@ class PromptService:
|
||||
- 体现世界观特色
|
||||
|
||||
5. **承上启下**:
|
||||
- 开头自然衔接上一章结尾
|
||||
- 结尾为下一章做好铺垫
|
||||
- 开头自然衔接上一章结尾
|
||||
- 结尾为下一章做好铺垫
|
||||
|
||||
6. **记忆系统使用指南**:
|
||||
- **最近章节记忆**:保持情节连贯,注意角色状态和剧情发展
|
||||
- **语义相关记忆**:参考相似情节的处理方式
|
||||
- **未完结伏笔**:适当时机可以回收伏笔,制造呼应效果
|
||||
- **角色状态记忆**:确保角色行为符合其发展轨迹
|
||||
- **重要情节点**:与关键剧情保持一致
|
||||
|
||||
请直接输出章节正文内容,不要包含章节标题和其他说明文字。"""
|
||||
|
||||
@@ -746,14 +765,26 @@ class PromptService:
|
||||
characters_info: str, outlines_context: str,
|
||||
chapter_number: int, chapter_title: str,
|
||||
chapter_outline: str, style_content: str = "",
|
||||
target_word_count: int = 3000) -> str:
|
||||
target_word_count: int = 3000,
|
||||
memory_context: dict = None) -> str:
|
||||
"""
|
||||
获取章节完整创作提示词
|
||||
|
||||
Args:
|
||||
style_content: 写作风格要求内容,如果提供则会追加到提示词中
|
||||
target_word_count: 目标字数,默认3000字
|
||||
memory_context: 记忆上下文(可选)
|
||||
"""
|
||||
# 格式化记忆上下文
|
||||
memory_text = ""
|
||||
if memory_context:
|
||||
memory_text = "\n【🧠 智能记忆系统 - 重要参考】\n"
|
||||
memory_text += memory_context.get('recent_context', '')
|
||||
memory_text += "\n" + memory_context.get('relevant_memories', '')
|
||||
memory_text += "\n" + memory_context.get('foreshadows', '')
|
||||
memory_text += "\n" + memory_context.get('character_states', '')
|
||||
memory_text += "\n" + memory_context.get('plot_points', '')
|
||||
|
||||
base_prompt = cls.format_prompt(
|
||||
cls.CHAPTER_GENERATION,
|
||||
title=title,
|
||||
@@ -772,6 +803,13 @@ class PromptService:
|
||||
target_word_count=target_word_count
|
||||
)
|
||||
|
||||
# 插入记忆上下文
|
||||
if memory_text:
|
||||
base_prompt = base_prompt.replace(
|
||||
"本章信息:",
|
||||
memory_text + "\n\n本章信息:"
|
||||
)
|
||||
|
||||
# 如果有风格要求,应用到提示词中
|
||||
if style_content:
|
||||
return WritingStyleManager.apply_style_to_prompt(base_prompt, style_content)
|
||||
@@ -786,14 +824,27 @@ class PromptService:
|
||||
previous_content: str, chapter_number: int,
|
||||
chapter_title: str, chapter_outline: str,
|
||||
style_content: str = "",
|
||||
target_word_count: int = 3000) -> str:
|
||||
target_word_count: int = 3000,
|
||||
memory_context: dict = None) -> str:
|
||||
"""
|
||||
获取章节完整创作提示词(带前置章节上下文)
|
||||
获取章节完整创作提示词(带前置章节上下文和记忆增强)
|
||||
|
||||
Args:
|
||||
style_content: 写作风格要求内容,如果提供则会追加到提示词中
|
||||
target_word_count: 目标字数,默认3000字
|
||||
memory_context: 记忆上下文(可选)
|
||||
"""
|
||||
# 格式化记忆上下文
|
||||
memory_text = ""
|
||||
if memory_context:
|
||||
memory_text = memory_context.get('recent_context', '')
|
||||
memory_text += "\n" + memory_context.get('relevant_memories', '')
|
||||
memory_text += "\n" + memory_context.get('foreshadows', '')
|
||||
memory_text += "\n" + memory_context.get('character_states', '')
|
||||
memory_text += "\n" + memory_context.get('plot_points', '')
|
||||
else:
|
||||
memory_text = "暂无相关记忆"
|
||||
|
||||
base_prompt = cls.format_prompt(
|
||||
cls.CHAPTER_GENERATION_WITH_CONTEXT,
|
||||
title=title,
|
||||
@@ -810,7 +861,8 @@ class PromptService:
|
||||
chapter_number=chapter_number,
|
||||
chapter_title=chapter_title,
|
||||
chapter_outline=chapter_outline,
|
||||
target_word_count=target_word_count
|
||||
target_word_count=target_word_count,
|
||||
memory_context=memory_text
|
||||
)
|
||||
|
||||
# 如果有风格要求,应用到提示词中
|
||||
@@ -839,9 +891,22 @@ class PromptService:
|
||||
current_chapter_count: int, all_chapters_brief: str,
|
||||
recent_plot: str, plot_stage_instruction: str,
|
||||
start_chapter: int, story_direction: str,
|
||||
requirements: str = "") -> str:
|
||||
"""获取大纲续写提示词"""
|
||||
requirements: str = "",
|
||||
memory_context: dict = None) -> str:
|
||||
"""获取大纲续写提示词(支持记忆增强)"""
|
||||
end_chapter = start_chapter + chapter_count - 1
|
||||
|
||||
# 格式化记忆上下文
|
||||
memory_text = ""
|
||||
if memory_context:
|
||||
memory_text = memory_context.get('recent_context', '')
|
||||
memory_text += "\n" + memory_context.get('relevant_memories', '')
|
||||
memory_text += "\n" + memory_context.get('foreshadows', '')
|
||||
memory_text += "\n" + memory_context.get('character_states', '')
|
||||
memory_text += "\n" + memory_context.get('plot_points', '')
|
||||
else:
|
||||
memory_text = "暂无相关记忆(可能是首次续写或记忆库为空)"
|
||||
|
||||
return cls.format_prompt(
|
||||
cls.OUTLINE_CONTINUE_GENERATION,
|
||||
title=title,
|
||||
@@ -861,7 +926,8 @@ class PromptService:
|
||||
start_chapter=start_chapter,
|
||||
end_chapter=end_chapter,
|
||||
story_direction=story_direction,
|
||||
requirements=requirements or "无特殊要求"
|
||||
requirements=requirements or "无特殊要求",
|
||||
memory_context=memory_text
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user