update:1.更新导入导出功能 2.实现RAG记忆功能,引入剧情分析功能

This commit is contained in:
xiamuceer
2025-11-04 14:38:59 +08:00
parent 1cde345ed9
commit e4f90d5da0
26 changed files with 6722 additions and 84 deletions
@@ -0,0 +1,769 @@
"""导入导出服务"""
import json
from datetime import datetime
from typing import Dict, List, Optional, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.models.project import Project
from app.models.chapter import Chapter
from app.models.character import Character
from app.models.outline import Outline
from app.models.relationship import CharacterRelationship, Organization, OrganizationMember
from app.models.writing_style import WritingStyle
from app.models.generation_history import GenerationHistory
from app.schemas.import_export import (
ProjectExportData,
ChapterExportData,
CharacterExportData,
OutlineExportData,
RelationshipExportData,
OrganizationExportData,
OrganizationMemberExportData,
WritingStyleExportData,
GenerationHistoryExportData,
ImportValidationResult,
ImportResult
)
from app.logger import get_logger
# Module-level logger used by every method of the import/export service.
logger = get_logger(__name__)
class ImportExportService:
"""导入导出服务类"""
SUPPORTED_VERSION = "1.0.0"
@staticmethod
async def export_project(
    project_id: str,
    db: AsyncSession,
    include_generation_history: bool = False,
    include_writing_styles: bool = True
) -> ProjectExportData:
    """
    Export the complete data set of a single project.

    Args:
        project_id: ID of the project to export.
        db: Database session.
        include_generation_history: Whether generation history is exported.
        include_writing_styles: Whether writing styles are exported.

    Returns:
        ProjectExportData: The assembled export payload.

    Raises:
        ValueError: If no project with ``project_id`` exists.
    """
    svc = ImportExportService
    logger.info(f"开始导出项目: {project_id}")

    # Load the project row; a missing project aborts the export.
    lookup = await db.execute(select(Project).where(Project.id == project_id))
    project = lookup.scalar_one_or_none()
    if not project:
        raise ValueError(f"项目不存在: {project_id}")

    # Plain columns are copied verbatim; created_at is serialized separately.
    copied_fields = (
        "title",
        "description",
        "theme",
        "genre",
        "target_words",
        "current_words",
        "status",
        "world_time_period",
        "world_location",
        "world_atmosphere",
        "world_rules",
        "chapter_count",
        "narrative_perspective",
        "character_count",
    )
    project_data = {field: getattr(project, field) for field in copied_fields}
    created = project.created_at
    project_data["created_at"] = created.isoformat() if created else None

    # Sections that are always exported, in the order they are logged.
    sections = {}
    for key, label, exporter in (
        ("chapters", "章节", svc._export_chapters),
        ("characters", "角色", svc._export_characters),
        ("outlines", "大纲", svc._export_outlines),
        ("relationships", "关系", svc._export_relationships),
        ("organizations", "组织", svc._export_organizations),
        ("organization_members", "组织成员", svc._export_organization_members),
    ):
        sections[key] = await exporter(project_id, db)
        logger.info(f"导出{label}数: {len(sections[key])}")

    # Optional sections stay empty lists when not requested.
    styles = []
    if include_writing_styles:
        styles = await svc._export_writing_styles(project_id, db)
        logger.info(f"导出写作风格数: {len(styles)}")

    history = []
    if include_generation_history:
        history = await svc._export_generation_history(project_id, db)
        logger.info(f"导出生成历史数: {len(history)}")

    payload = ProjectExportData(
        version=svc.SUPPORTED_VERSION,
        export_time=datetime.utcnow().isoformat(),  # naive UTC timestamp
        project=project_data,
        writing_styles=styles,
        generation_history=history,
        **sections
    )
    logger.info(f"项目导出完成: {project_id}")
    return payload
@staticmethod
async def _export_chapters(project_id: str, db: AsyncSession) -> List[ChapterExportData]:
    """Export all chapters of the project, ordered by chapter number."""
    stmt = (
        select(Chapter)
        .where(Chapter.project_id == project_id)
        .order_by(Chapter.chapter_number)
    )
    rows = (await db.execute(stmt)).scalars().all()

    exported = []
    for row in rows:
        created = row.created_at
        exported.append(
            ChapterExportData(
                title=row.title,
                content=row.content,
                summary=row.summary,
                chapter_number=row.chapter_number,
                word_count=row.word_count or 0,
                status=row.status,
                created_at=created.isoformat() if created else None
            )
        )
    return exported
@staticmethod
async def _export_characters(project_id: str, db: AsyncSession) -> List[CharacterExportData]:
    """
    Export all characters (including organization-type characters) of a project.

    ``traits`` is stored either as a JSON string or as an already-decoded
    value; strings are decoded here so the export contains structured data.
    A string that is not valid JSON is exported as ``None``.

    Args:
        project_id: ID of the project whose characters are exported.
        db: Database session.

    Returns:
        One CharacterExportData per character row.
    """
    result = await db.execute(
        select(Character).where(Character.project_id == project_id)
    )
    exported = []
    for char in result.scalars().all():
        traits = char.traits or None
        if isinstance(traits, str):
            try:
                traits = json.loads(traits)
            except ValueError:
                # json.JSONDecodeError is a ValueError; only malformed JSON is
                # tolerated here, anything else should surface to the caller.
                logger.warning(f"角色traits解析失败, 已忽略: {char.name}")
                traits = None
        exported.append(CharacterExportData(
            name=char.name,
            age=char.age,
            gender=char.gender,
            is_organization=char.is_organization or False,
            role_type=char.role_type,
            personality=char.personality,
            background=char.background,
            appearance=char.appearance,
            traits=traits,
            organization_type=char.organization_type,
            organization_purpose=char.organization_purpose,
            created_at=char.created_at.isoformat() if char.created_at else None
        ))
    return exported
@staticmethod
async def _export_outlines(project_id: str, db: AsyncSession) -> List[OutlineExportData]:
    """Export all outlines of the project, ordered by ``order_index``."""
    stmt = (
        select(Outline)
        .where(Outline.project_id == project_id)
        .order_by(Outline.order_index)
    )
    rows = (await db.execute(stmt)).scalars().all()

    exported = []
    for row in rows:
        created = row.created_at
        exported.append(
            OutlineExportData(
                title=row.title,
                content=row.content,
                structure=row.structure,
                order_index=row.order_index,
                created_at=created.isoformat() if created else None
            )
        )
    return exported
@staticmethod
async def _export_relationships(project_id: str, db: AsyncSession) -> List[RelationshipExportData]:
    """
    Export character relationships, referencing both ends by character name.

    Relationships whose target character no longer exists are skipped.

    Args:
        project_id: ID of the project whose relationships are exported.
        db: Database session.

    Returns:
        One RelationshipExportData per relationship with a resolvable target.
    """
    result = await db.execute(
        select(CharacterRelationship, Character)
        .join(Character, CharacterRelationship.character_from_id == Character.id)
        .where(CharacterRelationship.project_id == project_id)
    )
    rows = result.all()

    # Resolve all target names in one query instead of one query per row.
    target_ids = list({
        rel.character_to_id for rel, _ in rows if rel.character_to_id is not None
    })
    target_names = {}
    if target_ids:
        name_rows = await db.execute(
            select(Character.id, Character.name).where(Character.id.in_(target_ids))
        )
        target_names = {char_id: name for char_id, name in name_rows.all()}

    exported = []
    for rel, char_from in rows:
        if rel.character_to_id not in target_names:
            continue
        intimacy = rel.intimacy_level
        exported.append(RelationshipExportData(
            source_name=char_from.name,
            target_name=target_names[rel.character_to_id],
            relationship_name=rel.relationship_name,
            # Only a missing value falls back to 50; a stored 0 is kept.
            intimacy_level=50 if intimacy is None else intimacy,
            status=rel.status or "active",
            description=rel.description,
            started_at=rel.started_at
        ))
    return exported
@staticmethod
async def _export_organizations(project_id: str, db: AsyncSession) -> List[OrganizationExportData]:
    """
    Export organization details, identified by the name of their character row.

    Args:
        project_id: ID of the project whose organizations are exported.
        db: Database session.

    Returns:
        One OrganizationExportData per organization; ``parent_org_name`` is
        None when there is no parent or the parent cannot be found.
    """
    result = await db.execute(
        select(Organization, Character)
        .join(Character, Organization.character_id == Character.id)
        .where(Organization.project_id == project_id)
    )
    rows = result.all()

    # Names of the organizations already loaded, keyed by organization id, so
    # parents inside the same result set need no extra query.
    names_by_org_id = {org.id: char.name for org, char in rows}

    exported = []
    for org, char in rows:
        parent_name = None
        if org.parent_org_id:
            if org.parent_org_id in names_by_org_id:
                parent_name = names_by_org_id[org.parent_org_id]
            else:
                # Parent is not among this project's rows; look it up directly.
                parent_result = await db.execute(
                    select(Organization, Character)
                    .join(Character, Organization.character_id == Character.id)
                    .where(Organization.id == org.parent_org_id)
                )
                parent_data = parent_result.first()
                if parent_data:
                    parent_name = parent_data[1].name
        power = org.power_level
        exported.append(OrganizationExportData(
            character_name=char.name,
            parent_org_name=parent_name,
            # Only a missing value falls back to 50; a stored 0 is kept.
            power_level=50 if power is None else power,
            member_count=org.member_count or 0,
            location=org.location,
            motto=org.motto,
            color=org.color
        ))
    return exported
@staticmethod
async def _export_organization_members(project_id: str, db: AsyncSession) -> List[OrganizationMemberExportData]:
    """
    Export organization memberships, referencing both sides by character name.

    Memberships whose member character no longer exists are skipped.

    Args:
        project_id: ID of the project whose memberships are exported.
        db: Database session.

    Returns:
        One OrganizationMemberExportData per membership with a resolvable member.
    """
    result = await db.execute(
        select(OrganizationMember, Organization, Character)
        .join(Organization, OrganizationMember.organization_id == Organization.id)
        .join(Character, Organization.character_id == Character.id)
        .where(Organization.project_id == project_id)
    )
    rows = result.all()

    # Resolve all member names in one query instead of one query per row.
    member_ids = list({
        member.character_id for member, _, _ in rows if member.character_id is not None
    })
    member_names = {}
    if member_ids:
        name_rows = await db.execute(
            select(Character.id, Character.name).where(Character.id.in_(member_ids))
        )
        member_names = {char_id: name for char_id, name in name_rows.all()}

    exported = []
    for member, _org, org_char in rows:
        if member.character_id not in member_names:
            continue
        loyalty = member.loyalty
        exported.append(OrganizationMemberExportData(
            organization_name=org_char.name,
            character_name=member_names[member.character_id],
            position=member.position,
            rank=member.rank or 0,
            status=member.status or "active",
            joined_at=member.joined_at,
            # Only a missing value falls back to 50; a stored 0 is kept.
            loyalty=50 if loyalty is None else loyalty,
            contribution=member.contribution or 0,
            notes=member.notes
        ))
    return exported
@staticmethod
async def _export_writing_styles(project_id: str, db: AsyncSession) -> List[WritingStyleExportData]:
    """Export the project's writing styles, ordered by ``order_index``."""
    stmt = (
        select(WritingStyle)
        .where(WritingStyle.project_id == project_id)
        .order_by(WritingStyle.order_index)
    )
    rows = (await db.execute(stmt)).scalars().all()

    exported = []
    for row in rows:
        exported.append(
            WritingStyleExportData(
                name=row.name,
                style_type=row.style_type,
                preset_id=row.preset_id,
                description=row.description,
                prompt_content=row.prompt_content,
                order_index=row.order_index or 0
            )
        )
    return exported
@staticmethod
async def _export_generation_history(project_id: str, db: AsyncSession) -> List[GenerationHistoryExportData]:
    """Export the 100 most recent generation-history records of the project."""
    stmt = (
        select(GenerationHistory, Chapter)
        .outerjoin(Chapter, GenerationHistory.chapter_id == Chapter.id)
        .where(GenerationHistory.project_id == project_id)
        .order_by(GenerationHistory.created_at.desc())
        .limit(100)  # export at most 100 history records
    )
    rows = (await db.execute(stmt)).all()

    exported = []
    for record, linked_chapter in rows:
        # The outer join yields no chapter when the history row has none.
        created = record.created_at
        exported.append(
            GenerationHistoryExportData(
                chapter_title=linked_chapter.title if linked_chapter else None,
                prompt=record.prompt,
                generated_content=record.generated_content,
                model=record.model,
                tokens_used=record.tokens_used,
                generation_time=record.generation_time,
                created_at=created.isoformat() if created else None
            )
        )
    return exported
@staticmethod
def validate_import_data(data: Dict) -> ImportValidationResult:
    """
    Validate an import payload without touching the database.

    Malformed input (non-object top level, non-object ``project``, sections
    that are not arrays of objects) is reported as validation errors rather
    than raising. When the project title is missing or empty the reported
    project name is the placeholder "未知项目".

    Args:
        data: The decoded JSON payload to import.

    Returns:
        ImportValidationResult: ``valid`` is True only when no errors were
        found; ``statistics`` holds the item count of every section.
    """
    errors = []
    warnings = []
    section_names = (
        "chapters",
        "characters",
        "outlines",
        "relationships",
        "organizations",
        "organization_members",
        "writing_styles",
        "generation_history",
    )

    if not isinstance(data, dict):
        return ImportValidationResult(
            valid=False,
            version="",
            project_name="未知项目",
            statistics={},
            errors=["导入数据格式错误: 顶层必须是JSON对象"],
            warnings=warnings
        )

    # Version check: missing is an error, a mismatch only a warning.
    raw_version = data.get("version", "")
    version = str(raw_version) if raw_version else ""
    if not version:
        errors.append("缺少版本信息")
    elif version != ImportExportService.SUPPORTED_VERSION:
        warnings.append(f"版本不匹配: 导入文件版本为 {version}, 当前支持版本为 {ImportExportService.SUPPORTED_VERSION}")

    # Required project section.
    project_name = "未知项目"
    if "project" not in data:
        errors.append("缺少项目信息")
    else:
        project = data["project"]
        if not isinstance(project, dict):
            errors.append("项目信息格式错误")
        else:
            title = project.get("title")
            if title:
                project_name = title
            else:
                errors.append("项目标题不能为空")

    # Per-section counts; the importer iterates these as lists of dicts, so
    # anything else is rejected here instead of failing mid-import.
    statistics = {}
    for section in section_names:
        items = data.get(section, [])
        if not isinstance(items, list) or not all(isinstance(item, dict) for item in items):
            errors.append(f"{section} 数据格式错误: 应为对象数组")
            statistics[section] = 0
        else:
            statistics[section] = len(items)

    # Completeness hints.
    if statistics["chapters"] == 0:
        warnings.append("项目没有章节数据")
    if statistics["characters"] == 0:
        warnings.append("项目没有角色数据")

    return ImportValidationResult(
        valid=len(errors) == 0,
        version=version,
        project_name=project_name,
        statistics=statistics,
        errors=errors,
        warnings=warnings
    )
@staticmethod
async def import_project(
    data: Dict,
    db: AsyncSession
) -> ImportResult:
    """
    Import a payload as a brand-new project.

    The payload is validated first; on success the project and all of its
    sections are written in one transaction. Any exception rolls the
    transaction back and is reported through the returned ImportResult
    rather than raised.

    Args:
        data: The decoded JSON payload to import.
        db: Database session.

    Returns:
        ImportResult: Success flag, new project id (on success), message,
        per-section statistics and accumulated warnings.
    """
    warnings = []
    statistics = {}
    svc = ImportExportService
    try:
        # Validate before touching the database.
        validation = svc.validate_import_data(data)
        if not validation.valid:
            return ImportResult(
                success=False,
                message=f"数据验证失败: {', '.join(validation.errors)}",
                statistics={},
                warnings=validation.warnings
            )
        warnings.extend(validation.warnings)
        logger.info(f"开始导入项目: {validation.project_name}")

        # Create the project row.
        project_data = data["project"]
        new_project = Project(
            title=project_data.get("title"),
            description=project_data.get("description"),
            theme=project_data.get("theme"),
            genre=project_data.get("genre"),
            target_words=project_data.get("target_words"),
            status=project_data.get("status", "planning"),
            world_time_period=project_data.get("world_time_period"),
            world_location=project_data.get("world_location"),
            world_atmosphere=project_data.get("world_atmosphere"),
            world_rules=project_data.get("world_rules"),
            chapter_count=project_data.get("chapter_count"),
            narrative_perspective=project_data.get("narrative_perspective"),
            character_count=project_data.get("character_count"),
            current_words=project_data.get("current_words", 0),  # keep the source project's word count
            wizard_step=4,  # imported projects start with the wizard finished
            wizard_status="completed"
        )
        db.add(new_project)
        await db.flush()  # assigns the project id

        # Read the id once, before commit: after commit the instance may be
        # expired (session default expire_on_commit=True) and touching an
        # expired attribute on an AsyncSession would attempt implicit IO.
        project_id = new_project.id
        logger.info(f"创建项目成功: {project_id}")

        # Chapters
        chapters_count = await svc._import_chapters(
            project_id, data.get("chapters", []), db
        )
        statistics["chapters"] = chapters_count
        logger.info(f"导入章节数: {chapters_count}")

        # Characters (organizations are characters too); name -> id mapping
        char_mapping = await svc._import_characters(
            project_id, data.get("characters", []), db
        )
        statistics["characters"] = len(char_mapping)
        logger.info(f"导入角色数: {len(char_mapping)}")

        # Outlines
        outlines_count = await svc._import_outlines(
            project_id, data.get("outlines", []), db
        )
        statistics["outlines"] = outlines_count
        logger.info(f"导入大纲数: {outlines_count}")

        # Relationships
        relationships_count = await svc._import_relationships(
            project_id, data.get("relationships", []), char_mapping, db
        )
        statistics["relationships"] = relationships_count
        logger.info(f"导入关系数: {relationships_count}")

        # Organization details; name -> organization id mapping
        org_mapping = await svc._import_organizations(
            project_id, data.get("organizations", []), char_mapping, db
        )
        statistics["organizations"] = len(org_mapping)
        logger.info(f"导入组织数: {len(org_mapping)}")

        # Organization members
        org_members_count = await svc._import_organization_members(
            data.get("organization_members", []), char_mapping, org_mapping, db
        )
        statistics["organization_members"] = org_members_count
        logger.info(f"导入组织成员数: {org_members_count}")

        # Writing styles
        styles_count = await svc._import_writing_styles(
            project_id, data.get("writing_styles", []), db
        )
        statistics["writing_styles"] = styles_count
        logger.info(f"导入写作风格数: {styles_count}")

        await db.commit()
        logger.info(f"项目导入完成: {project_id}")
        return ImportResult(
            success=True,
            project_id=project_id,
            message="项目导入成功",
            statistics=statistics,
            warnings=warnings
        )
    except Exception as e:
        # Top-level boundary: undo partial writes and report the failure.
        await db.rollback()
        logger.error(f"导入项目失败: {str(e)}", exc_info=True)
        return ImportResult(
            success=False,
            message=f"导入失败: {str(e)}",
            statistics=statistics,
            warnings=warnings
        )
@staticmethod
async def _import_chapters(
    project_id: str,
    chapters_data: List[Dict],
    db: AsyncSession
) -> int:
    """Add one Chapter per payload entry to the session; return how many were added."""
    imported = 0
    for imported, entry in enumerate(chapters_data, 1):
        db.add(
            Chapter(
                project_id=project_id,
                title=entry.get("title"),
                content=entry.get("content"),
                summary=entry.get("summary"),
                chapter_number=entry.get("chapter_number"),
                word_count=entry.get("word_count", 0),
                status=entry.get("status", "draft")
            )
        )
    return imported
@staticmethod
async def _import_characters(
    project_id: str,
    characters_data: List[Dict],
    db: AsyncSession
) -> Dict[str, str]:
    """
    Create the project's characters and return a name -> id mapping.

    Structured ``traits`` (list or dict, as produced by the exporter's JSON
    decoding) are serialized back to a JSON string before being stored.
    When several entries share a name, the last one wins in the mapping.

    Args:
        project_id: ID of the project the characters belong to.
        characters_data: Character entries from the import payload.
        db: Database session.

    Returns:
        Mapping from character name to the new character's id.
    """
    pending = []
    for char_data in characters_data:
        traits = char_data.get("traits")
        if isinstance(traits, (list, dict)):
            traits = json.dumps(traits, ensure_ascii=False)
        character = Character(
            project_id=project_id,
            name=char_data.get("name"),
            age=char_data.get("age"),
            gender=char_data.get("gender"),
            is_organization=char_data.get("is_organization", False),
            role_type=char_data.get("role_type"),
            personality=char_data.get("personality"),
            background=char_data.get("background"),
            appearance=char_data.get("appearance"),
            traits=traits,
            organization_type=char_data.get("organization_type"),
            organization_purpose=char_data.get("organization_purpose")
        )
        db.add(character)
        pending.append((char_data.get("name"), character))

    # One flush for the whole batch is enough to populate every id.
    await db.flush()
    return {name: character.id for name, character in pending}
@staticmethod
async def _import_outlines(
    project_id: str,
    outlines_data: List[Dict],
    db: AsyncSession
) -> int:
    """Add one Outline per payload entry to the session; return how many were added."""
    imported = 0
    for imported, entry in enumerate(outlines_data, 1):
        db.add(
            Outline(
                project_id=project_id,
                title=entry.get("title"),
                content=entry.get("content"),
                structure=entry.get("structure"),
                order_index=entry.get("order_index")
            )
        )
    return imported
@staticmethod
async def _import_relationships(
    project_id: str,
    relationships_data: List[Dict],
    char_mapping: Dict[str, str],
    db: AsyncSession
) -> int:
    """
    Create character relationships, resolving both ends by character name.

    Entries whose source or target name is not in ``char_mapping`` cannot be
    linked; they are skipped and a warning is logged for each.

    Args:
        project_id: ID of the project the relationships belong to.
        relationships_data: Relationship entries from the import payload.
        char_mapping: Character name -> id mapping of the imported characters.
        db: Database session.

    Returns:
        Number of relationships added to the session.
    """
    count = 0
    for rel_data in relationships_data:
        source_name = rel_data.get("source_name")
        target_name = rel_data.get("target_name")
        source_id = char_mapping.get(source_name)
        target_id = char_mapping.get(target_name)
        if not (source_id and target_id):
            logger.warning(f"跳过关系导入, 未找到角色: {source_name} -> {target_name}")
            continue
        db.add(CharacterRelationship(
            project_id=project_id,
            character_from_id=source_id,
            character_to_id=target_id,
            relationship_name=rel_data.get("relationship_name"),
            intimacy_level=rel_data.get("intimacy_level", 50),
            status=rel_data.get("status", "active"),
            description=rel_data.get("description"),
            started_at=rel_data.get("started_at")
        ))
        count += 1
    return count
@staticmethod
async def _import_organizations(
    project_id: str,
    organizations_data: List[Dict],
    char_mapping: Dict[str, str],
    db: AsyncSession
) -> Dict[str, str]:
    """
    Create organization detail rows and return a name -> organization id mapping.

    Works in two passes: first every organization is created without a
    parent, then parent links are set once all ids are known. Entries whose
    character name is not in ``char_mapping`` are skipped.

    Args:
        project_id: ID of the project the organizations belong to.
        organizations_data: Organization entries from the import payload.
        char_mapping: Character name -> id mapping of the imported characters.
        db: Database session.

    Returns:
        Mapping from the organization's character name to the organization id.
    """
    # Pass 1: create all organizations, remembering name and wanted parent.
    created = []
    for org_data in organizations_data:
        char_name = org_data.get("character_name")
        char_id = char_mapping.get(char_name)
        if not char_id:
            continue
        organization = Organization(
            project_id=project_id,
            character_id=char_id,
            power_level=org_data.get("power_level", 50),
            member_count=org_data.get("member_count", 0),
            location=org_data.get("location"),
            motto=org_data.get("motto"),
            color=org_data.get("color")
        )
        db.add(organization)
        created.append((organization, char_name, org_data.get("parent_org_name")))

    await db.flush()  # populates the id of every new organization

    # The character name is already known from the payload, so the mapping
    # needs no database round-trip.
    org_mapping = {char_name: org.id for org, char_name, _ in created}

    # Pass 2: link parents that were part of this import.
    for org, _, parent_name in created:
        if parent_name:
            parent_id = org_mapping.get(parent_name)
            if parent_id:
                org.parent_org_id = parent_id
    return org_mapping
@staticmethod
async def _import_organization_members(
    org_members_data: List[Dict],
    char_mapping: Dict[str, str],
    org_mapping: Dict[str, str],
    db: AsyncSession
) -> int:
    """
    Create organization memberships, resolving both sides by name.

    Entries whose organization or character name cannot be resolved are
    skipped and a warning is logged for each.

    Args:
        org_members_data: Membership entries from the import payload.
        char_mapping: Character name -> id mapping of the imported characters.
        org_mapping: Organization name -> id mapping of the imported organizations.
        db: Database session.

    Returns:
        Number of memberships added to the session.
    """
    count = 0
    for member_data in org_members_data:
        org_name = member_data.get("organization_name")
        char_name = member_data.get("character_name")
        org_id = org_mapping.get(org_name)
        char_id = char_mapping.get(char_name)
        if not (org_id and char_id):
            logger.warning(f"跳过组织成员导入, 未找到组织或角色: {org_name} / {char_name}")
            continue
        db.add(OrganizationMember(
            organization_id=org_id,
            character_id=char_id,
            position=member_data.get("position"),
            rank=member_data.get("rank", 0),
            status=member_data.get("status", "active"),
            joined_at=member_data.get("joined_at"),
            loyalty=member_data.get("loyalty", 50),
            contribution=member_data.get("contribution", 0),
            notes=member_data.get("notes")
        ))
        count += 1
    return count
@staticmethod
async def _import_writing_styles(
    project_id: str,
    styles_data: List[Dict],
    db: AsyncSession
) -> int:
    """Add one WritingStyle per payload entry to the session; return how many were added."""
    imported = 0
    for imported, entry in enumerate(styles_data, 1):
        db.add(
            WritingStyle(
                project_id=project_id,
                name=entry.get("name"),
                style_type=entry.get("style_type"),
                preset_id=entry.get("preset_id"),
                description=entry.get("description"),
                prompt_content=entry.get("prompt_content"),
                order_index=entry.get("order_index", 0)
            )
        )
    return imported
+739
View File
@@ -0,0 +1,739 @@
"""向量记忆服务 - 基于ChromaDB实现长期记忆和语义检索"""
import chromadb
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Any, Optional
import json
from datetime import datetime
from app.logger import get_logger
import os
import hashlib
# Module-level logger for the memory service.
logger = get_logger(__name__)
# Request offline mode to avoid network checks by the model libraries.
# NOTE(review): these variables are assigned after sentence_transformers has
# already been imported above; if the libraries read them at import time the
# assignments have no effect — confirm the intended ordering.
os.environ['TRANSFORMERS_OFFLINE'] = '1'
os.environ['HF_DATASETS_OFFLINE'] = '1'
class MemoryService:
"""向量记忆管理服务 - 实现语义检索和长期记忆"""
_instance = None
_initialized = False
def __new__(cls):
"""单例模式"""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
    """
    Initialize the ChromaDB client and the embedding model.

    Runs only once per process: later constructions of the singleton return
    immediately. The multilingual model is tried first and a smaller model
    is used as fallback.

    Raises:
        RuntimeError: If neither embedding model can be loaded.
        Exception: Any other initialization failure is logged and re-raised.
    """
    if self._initialized:
        return
    try:
        # Make sure the data directory exists.
        chroma_dir = "data/chroma_db"
        os.makedirs(chroma_dir, exist_ok=True)
        # Persistent (on-disk) ChromaDB client.
        self.client = chromadb.PersistentClient(path=chroma_dir)

        logger.info("🔄 正在加载Embedding模型...")
        # Make sure the model cache directory exists.
        model_cache_dir = 'data/models'
        os.makedirs(model_cache_dir, exist_ok=True)

        primary_model = 'paraphrase-multilingual-MiniLM-L12-v2'
        fallback_model = 'all-MiniLM-L6-v2'
        try:
            self.embedding_model = SentenceTransformer(
                primary_model,
                cache_folder=model_cache_dir,
                device='cpu'  # explicitly run on CPU
            )
            loaded_model = primary_model
            logger.info("✅ Embedding模型加载成功")
        except Exception as e:
            logger.warning(f"⚠️ 无法加载多语言模型: {str(e)}")
            logger.info("🔄 尝试使用备用模型...")
            try:
                # Degrade to the smaller model.
                self.embedding_model = SentenceTransformer(
                    fallback_model,
                    cache_folder=model_cache_dir,
                    device='cpu'
                )
                loaded_model = fallback_model
                logger.info("✅ 使用备用Embedding模型")
            except Exception as e2:
                logger.error(f"❌ 所有模型加载失败: {str(e2)}")
                logger.error("💡 模型首次使用需要联网下载(约420MB)")
                logger.error(" 或手动下载模型文件到 data/models 目录")
                # Keep the underlying cause attached for debugging.
                raise RuntimeError("无法加载任何Embedding模型") from e2

        self._initialized = True
        logger.info("✅ MemoryService初始化成功")
        logger.info(f" - ChromaDB目录: {chroma_dir}")
        # Report the model that was actually loaded, not always the primary.
        logger.info(f" - Embedding模型: {loaded_model}")
    except Exception as e:
        logger.error(f"❌ MemoryService初始化失败: {str(e)}")
        raise
def get_collection(self, user_id: str, project_id: str):
    """
    Get or create the memory collection of one user's project.

    Every (user, project) pair maps to its own collection name, built from
    the first 8 hex characters of the SHA-256 of each id:
    ``u_{user_hash}_p_{project_hash}``.

    Args:
        user_id: User ID.
        project_id: Project ID.

    Returns:
        The ChromaDB collection object.

    Raises:
        Exception: Whatever the client raises; it is logged and re-raised.
    """
    # Hashing keeps the name short and restricted to [0-9a-f_] regardless of
    # what characters the raw ids contain.
    digests = [
        hashlib.sha256(raw.encode()).hexdigest()[:8]
        for raw in (user_id, project_id)
    ]
    collection_name = "u_{}_p_{}".format(*digests)

    collection_meta = {
        "user_id": user_id,
        "project_id": project_id,
        "created_at": datetime.now().isoformat()
    }
    try:
        return self.client.get_or_create_collection(
            name=collection_name,
            metadata=collection_meta
        )
    except Exception as e:
        logger.error(f"❌ 获取collection失败: {str(e)}")
        raise
async def add_memory(
    self,
    user_id: str,
    project_id: str,
    memory_id: str,
    content: str,
    memory_type: str,
    metadata: Dict[str, Any]
) -> bool:
    """
    Store one memory in the vector database.

    Args:
        user_id: User ID.
        project_id: Project ID.
        memory_id: Unique ID of the memory.
        content: Memory text; it is embedded and stored as the document.
        memory_type: Memory type label.
        metadata: Extra metadata for the memory.

    Returns:
        True on success, False if any step failed (the error is logged).
    """
    try:
        target = self.get_collection(user_id, project_id)
        vector = self.embedding_model.encode(content).tolist()

        # Metadata values are flattened to str/int/float; lists go in as JSON.
        record_meta = {
            "memory_type": memory_type,
            "chapter_id": str(metadata.get("chapter_id", "")),
            "chapter_number": int(metadata.get("chapter_number", 0)),
            "importance": float(metadata.get("importance_score", 0.5)),
            "tags": json.dumps(metadata.get("tags", []), ensure_ascii=False),
            "title": str(metadata.get("title", ""))[:200],  # cap the length
            "is_foreshadow": int(metadata.get("is_foreshadow", 0)),
            "created_at": datetime.now().isoformat()
        }
        related = metadata.get("related_characters")
        if related:
            record_meta["related_characters"] = json.dumps(related, ensure_ascii=False)

        target.add(
            ids=[memory_id],
            embeddings=[vector],
            documents=[content],
            metadatas=[record_meta]
        )
        logger.info(f"✅ 记忆已添加: {memory_id[:8]}... (类型:{memory_type}, 重要性:{record_meta['importance']})")
        return True
    except Exception as e:
        logger.error(f"❌ 添加记忆失败: {str(e)}")
        return False
async def batch_add_memories(
    self,
    user_id: str,
    project_id: str,
    memories: List[Dict[str, Any]]
) -> int:
    """
    Store several memories in one call.

    All texts are embedded with a single ``encode`` call and written with a
    single ``collection.add``. The stored metadata has the same fields as
    ``add_memory``, including ``related_characters`` when present.

    Args:
        user_id: User ID.
        project_id: Project ID.
        memories: Memory dicts, each with ``id``, ``content``, ``type`` and
            optionally ``metadata``.

    Returns:
        Number of memories added; 0 for empty input or on failure (the
        error is logged).
    """
    if not memories:
        return 0
    try:
        collection = self.get_collection(user_id, project_id)

        ids = [mem['id'] for mem in memories]
        documents = [mem['content'] for mem in memories]
        # Passing the whole list lets the model batch the work internally;
        # the result is one embedding row per document.
        embeddings = self.embedding_model.encode(documents).tolist()

        metadatas = []
        for mem in memories:
            metadata = mem.get('metadata') or {}
            chroma_metadata = {
                "memory_type": mem['type'],
                "chapter_id": str(metadata.get("chapter_id", "")),
                "chapter_number": int(metadata.get("chapter_number", 0)),
                "importance": float(metadata.get("importance_score", 0.5)),
                "tags": json.dumps(metadata.get("tags", []), ensure_ascii=False),
                "title": str(metadata.get("title", ""))[:200],
                "is_foreshadow": int(metadata.get("is_foreshadow", 0)),
                "created_at": datetime.now().isoformat()
            }
            # Same optional field as the single-item path.
            related = metadata.get("related_characters")
            if related:
                chroma_metadata["related_characters"] = json.dumps(
                    related,
                    ensure_ascii=False
                )
            metadatas.append(chroma_metadata)

        collection.add(
            ids=ids,
            embeddings=embeddings,
            documents=documents,
            metadatas=metadatas
        )
        logger.info(f"✅ 批量添加记忆成功: {len(memories)}")
        return len(memories)
    except Exception as e:
        logger.error(f"❌ 批量添加记忆失败: {str(e)}")
        return 0
async def search_memories(
    self,
    user_id: str,
    project_id: str,
    query: str,
    memory_types: Optional[List[str]] = None,
    limit: int = 10,
    min_importance: float = 0.0,
    chapter_range: Optional[tuple] = None
) -> List[Dict[str, Any]]:
    """
    Semantic search over the project's memories.

    Args:
        user_id: User ID.
        project_id: Project ID.
        query: Query text; it is embedded and used for the vector search.
        memory_types: Restrict results to these memory types.
        limit: Maximum number of results.
        min_importance: Minimum importance; 0 disables the filter.
        chapter_range: Inclusive chapter range ``(start, end)``.

    Returns:
        Matching memories in the order returned by the query; an empty list
        if anything fails (the error is logged, never raised).
    """
    try:
        collection = self.get_collection(user_id, project_id)
        query_vector = self.embedding_model.encode(query).tolist()

        # Collect the individual filter clauses.
        clauses = []
        if memory_types:
            clauses.append({"memory_type": {"$in": memory_types}})
        if min_importance > 0:
            clauses.append({"importance": {"$gte": min_importance}})
        if chapter_range:
            clauses.append({"chapter_number": {"$gte": chapter_range[0]}})
            clauses.append({"chapter_number": {"$lte": chapter_range[1]}})

        # No clause -> no filter; one clause -> bare; several -> "$and".
        if not clauses:
            where_filter = None
        elif len(clauses) == 1:
            where_filter = clauses[0]
        else:
            where_filter = {"$and": clauses}

        raw = collection.query(
            query_embeddings=[query_vector],
            n_results=limit,
            where=where_filter
        )

        # Results are nested per query; only one query was sent, so use [0].
        hits = []
        id_rows = raw['ids']
        if id_rows and id_rows[0]:
            has_distances = 'distances' in raw
            for pos, memory_id in enumerate(id_rows[0]):
                document = raw['documents'][0][pos]
                meta = raw['metadatas'][0][pos]
                # NOTE(review): "1 - distance" is only a similarity if the
                # collection's distance is bounded like cosine distance —
                # confirm which metric the collection uses.
                hits.append({
                    "id": memory_id,
                    "content": document,
                    "metadata": meta,
                    "similarity": 1 - raw['distances'][0][pos] if has_distances else 1.0,
                    "distance": raw['distances'][0][pos] if has_distances else 0.0
                })
        logger.info(f"🔍 语义搜索完成: 查询='{query[:30]}...', 找到{len(hits)}条记忆")
        return hits
    except Exception as e:
        logger.error(f"❌ 搜索记忆失败: {str(e)}")
        return []
async def get_recent_memories(
    self,
    user_id: str,
    project_id: str,
    current_chapter: int,
    recent_count: int = 3,
    min_importance: float = 0.5
) -> List[Dict[str, Any]]:
    """
    Fetch important memories from the chapters just before the current one.

    Args:
        user_id: User ID.
        project_id: Project ID.
        current_chapter: Current chapter number (excluded from the range).
        recent_count: How many preceding chapters to look at.
        min_importance: Minimum importance of returned memories.

    Returns:
        At most 20 memories, sorted by importance and then chapter number,
        both descending; an empty list on failure (the error is logged).
    """
    try:
        collection = self.get_collection(user_id, project_id)

        # Chapter window: [start_chapter, current_chapter)
        start_chapter = max(1, current_chapter - recent_count)
        window_filter = {
            "$and": [
                {"chapter_number": {"$gte": start_chapter}},
                {"chapter_number": {"$lt": current_chapter}},
                {"importance": {"$gte": min_importance}}
            ]
        }
        fetched = collection.get(where=window_filter, limit=100)  # fetch a generous pool first

        candidates = []
        if fetched['ids']:
            for pos, memory_id in enumerate(fetched['ids']):
                candidates.append({
                    "id": memory_id,
                    "content": fetched['documents'][pos],
                    "metadata": fetched['metadatas'][pos]
                })

        def rank(entry):
            meta = entry['metadata']
            return (float(meta.get('importance', 0)), int(meta.get('chapter_number', 0)))

        # Most important (then latest) first; keep the top 20.
        top_memories = sorted(candidates, key=rank, reverse=True)[:20]
        logger.info(f"📚 获取最近记忆: 章节{start_chapter}-{current_chapter-1}, 找到{len(top_memories)}")
        return top_memories
    except Exception as e:
        logger.error(f"❌ 获取最近记忆失败: {str(e)}")
        return []
async def find_unresolved_foreshadows(
    self,
    user_id: str,
    project_id: str,
    current_chapter: int
) -> List[Dict[str, Any]]:
    """
    Find foreshadowing memories that are still open.

    Selects memories with ``is_foreshadow == 1`` from chapters before the
    current one.

    Args:
        user_id: User ID.
        project_id: Project ID.
        current_chapter: Current chapter number.

    Returns:
        Up to 50 foreshadow memories, most important first; an empty list on
        failure (the error is logged).
    """
    try:
        collection = self.get_collection(user_id, project_id)

        # is_foreshadow == 1 marks a foreshadow that has been planted but not resolved.
        open_filter = {
            "$and": [
                {"is_foreshadow": 1},
                {"chapter_number": {"$lt": current_chapter}}
            ]
        }
        fetched = collection.get(where=open_filter, limit=50)

        foreshadows = []
        if fetched['ids']:
            for pos, memory_id in enumerate(fetched['ids']):
                foreshadows.append({
                    "id": memory_id,
                    "content": fetched['documents'][pos],
                    "metadata": fetched['metadatas'][pos]
                })

        def importance_of(entry):
            return float(entry['metadata'].get('importance', 0))

        foreshadows.sort(key=importance_of, reverse=True)
        logger.info(f"🎣 找到未完结伏笔: {len(foreshadows)}")
        return foreshadows
    except Exception as e:
        logger.error(f"❌ 查找伏笔失败: {str(e)}")
        return []
async def build_context_for_generation(
    self,
    user_id: str,
    project_id: str,
    current_chapter: int,
    chapter_outline: str,
    character_names: Optional[List[str]] = None
) -> Dict[str, Any]:
    """
    Build the memory context used when generating a chapter.

    Combines several retrieval strategies: recent chapters, semantic search
    on the outline, open foreshadows, character-related memories and
    important plot points. The retrieval helpers log their own failures and
    return empty lists, so this method has no error path of its own.

    Args:
        user_id: User ID.
        project_id: Project ID.
        current_chapter: Current chapter number.
        chapter_outline: Outline of the chapter being generated.
        character_names: Names of the characters involved, if any.

    Returns:
        Dict with one formatted text block per strategy plus a ``stats``
        dict holding the number of memories found by each.
    """
    logger.info(f"🧠 开始构建章节{current_chapter}的智能上下文...")

    # 1. Recent chapters (temporal continuity).
    recent = await self.get_recent_memories(
        user_id, project_id, current_chapter,
        recent_count=3, min_importance=0.5
    )

    # 2. Memories semantically related to the outline.
    relevant = await self.search_memories(
        user_id=user_id,
        project_id=project_id,
        query=chapter_outline,
        limit=10,
        min_importance=0.4
    )

    # 3. Open foreshadows.
    foreshadows = await self.find_unresolved_foreshadows(
        user_id, project_id, current_chapter
    )

    # 4. Character-related memories, only when characters were given.
    character_memories = []
    if character_names:
        character_query = " ".join(character_names) + " 角色 状态 关系"
        character_memories = await self.search_memories(
            user_id=user_id,
            project_id=project_id,
            query=character_query,
            memory_types=["character_event", "plot_point"],
            limit=8
        )

    # 5. Important plot points. search_memories already combines multiple
    # filters and swallows its own errors, so no fallback query is needed.
    plot_points = await self.search_memories(
        user_id=user_id,
        project_id=project_id,
        query="重要 转折 高潮 关键",
        memory_types=["plot_point", "hook"],
        limit=5,
        min_importance=0.7
    )

    context = {
        "recent_context": self._format_memories(recent, "最近章节记忆"),
        "relevant_memories": self._format_memories(relevant, "语义相关记忆"),
        "character_states": self._format_memories(character_memories, "角色相关记忆"),
        # Only the first five foreshadows are rendered; stats count them all.
        "foreshadows": self._format_memories(foreshadows[:5], "未完结伏笔"),
        "plot_points": self._format_memories(plot_points, "重要情节点"),
        "stats": {
            "recent_count": len(recent),
            "relevant_count": len(relevant),
            "character_count": len(character_memories),
            "foreshadow_count": len(foreshadows),
            "plot_point_count": len(plot_points)
        }
    }
    logger.info(f"✅ 上下文构建完成: 最近{len(recent)}条, 相关{len(relevant)}条, 伏笔{len(foreshadows)}")
    return context
def _format_memories(self, memories: List[Dict], section_title: str = "记忆") -> str:
"""
格式化记忆列表为文本
Args:
memories: 记忆列表
section_title: 章节标题
Returns:
格式化后的文本
"""
if not memories:
return f"{section_title}\n暂无相关记忆\n"
lines = [f"{section_title}"]
for i, mem in enumerate(memories, 1):
meta = mem.get('metadata', {})
chapter_num = meta.get('chapter_number', '?')
mem_type = meta.get('memory_type', '未知')
importance = float(meta.get('importance', 0.5))
title = meta.get('title', '')
content = mem['content']
# 格式: [序号] 第X章-类型(重要性) 标题: 内容
line = f"{i}. [第{chapter_num}章-{mem_type}{importance:.1f}]"
if title:
line += f" {title}: {content[:100]}"
else:
line += f" {content[:150]}"
lines.append(line)
return "\n".join(lines) + "\n"
async def delete_chapter_memories(
    self,
    user_id: str,
    project_id: str,
    chapter_id: str
) -> bool:
    """
    Delete every memory attached to one chapter.

    Args:
        user_id: User ID
        project_id: Project ID
        chapter_id: ID of the chapter whose memories are removed

    Returns:
        True when the operation completed (including the case where the
        chapter had no memories), False when any step raised.
    """
    try:
        collection = self.get_collection(user_id, project_id)

        # Look up the ids first so the number of removed records can be logged.
        matches = collection.get(where={"chapter_id": chapter_id})
        memory_ids = matches['ids']

        if memory_ids:
            collection.delete(ids=memory_ids)
            logger.info(f"🗑️ 已删除章节{chapter_id[:8]}的{len(memory_ids)}条记忆")
        else:
            logger.info(f"ℹ️ 章节{chapter_id[:8]}没有记忆需要删除")
        return True

    except Exception as e:
        # Best-effort operation: report the failure to the caller instead of raising.
        logger.error(f"❌ 删除章节记忆失败: {str(e)}")
        return False
async def update_memory(
    self,
    user_id: str,
    project_id: str,
    memory_id: str,
    content: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None
) -> bool:
    """
    Update the text and/or metadata of a single memory.

    Args:
        user_id: User ID
        project_id: Project ID
        memory_id: ID of the memory to update
        content: Replacement text (optional); a new embedding is computed for it
        metadata: Replacement metadata (optional); list and dict values are
            stored as JSON strings

    Returns:
        True when an update was issued, False when nothing was supplied or
        an error occurred.
    """
    try:
        collection = self.get_collection(user_id, project_id)
        changes = {}

        if content:
            # New text needs a fresh embedding to stay searchable.
            vector = self.embedding_model.encode(content).tolist()
            changes['embeddings'] = [vector]
            changes['documents'] = [content]

        if metadata:
            # Nested containers are serialised to JSON text before storing.
            changes['metadatas'] = [{
                key: (
                    json.dumps(value, ensure_ascii=False)
                    if isinstance(value, (list, dict))
                    else value
                )
                for key, value in metadata.items()
            }]

        if not changes:
            logger.warning("⚠️ 没有提供更新内容")
            return False

        collection.update(ids=[memory_id], **changes)
        logger.info(f"✅ 记忆已更新: {memory_id[:8]}...")
        return True

    except Exception as e:
        logger.error(f"❌ 更新记忆失败: {str(e)}")
        return False
async def get_memory_stats(
    self,
    user_id: str,
    project_id: str
) -> Dict[str, Any]:
    """
    Collect counts describing the memories stored for a project.

    Args:
        user_id: User ID
        project_id: Project ID

    Returns:
        A dict with ``total_count``, ``by_type``, ``by_chapter`` (keyed by the
        chapter number as a string), ``foreshadow_count`` (``is_foreshadow``
        equal to 1) and ``foreshadow_resolved`` (``is_foreshadow`` equal to 2).
        The same keys are returned for an empty collection. On failure the
        dict contains a single ``error`` entry.
    """
    try:
        collection = self.get_collection(user_id, project_id)

        # Only ids and metadata are read below, so skip loading the documents.
        all_memories = collection.get(include=["metadatas"])
        memory_ids = all_memories['ids']

        if not memory_ids:
            return {
                "total_count": 0,
                "by_type": {},
                "by_chapter": {},
                "foreshadow_count": 0,
                "foreshadow_resolved": 0
            }

        type_counts = {}
        chapter_counts = {}
        foreshadow_count = 0
        foreshadow_resolved = 0

        # Single pass over the metadata; a record without metadata counts
        # under the default type/chapter.
        for meta in all_memories['metadatas']:
            meta = meta or {}
            mem_type = meta.get('memory_type', 'unknown')
            chapter_key = str(meta.get('chapter_number', 0))
            is_foreshadow = meta.get('is_foreshadow', 0)

            type_counts[mem_type] = type_counts.get(mem_type, 0) + 1
            chapter_counts[chapter_key] = chapter_counts.get(chapter_key, 0) + 1

            if is_foreshadow == 1:
                foreshadow_count += 1
            elif is_foreshadow == 2:
                foreshadow_resolved += 1

        stats = {
            "total_count": len(memory_ids),
            "by_type": type_counts,
            "by_chapter": chapter_counts,
            "foreshadow_count": foreshadow_count,
            "foreshadow_resolved": foreshadow_resolved
        }

        logger.info(f"📊 记忆统计: 总计{stats['total_count']}条, 伏笔{foreshadow_count}")
        return stats

    except Exception as e:
        logger.error(f"❌ 获取统计信息失败: {str(e)}")
        return {"error": str(e)}
# Module-level MemoryService instance, created once when this module is imported.
memory_service = MemoryService()
+559
View File
@@ -0,0 +1,559 @@
"""剧情分析服务 - 自动分析章节的钩子、伏笔、冲突等元素"""
from typing import Dict, Any, List, Optional
from app.services.ai_service import AIService
from app.logger import get_logger
import json
import re
logger = get_logger(__name__)
class PlotAnalyzer:
    """Plot analyzer that asks an AI service to review chapter content.

    The analyzer builds a prompt from a chapter, parses the JSON the model
    returns, converts the analysis into memory records and can render a short
    text report. Model output is treated as untrusted: malformed items are
    skipped or defaulted instead of aborting the whole analysis.
    """

    # Maximum number of characters of chapter text sent to the model
    # (keeps the prompt within the token budget).
    _MAX_ANALYSIS_CHARS = 8000

    # Top-level keys every parsed analysis is guaranteed to contain, paired
    # with the factory used to fill in a missing one.
    _REQUIRED_FIELD_DEFAULTS = (('hooks', list), ('plot_points', list), ('scores', dict))

    # Characters ignored by the punctuation-insensitive keyword match: ASCII
    # forms, CJK marks, and the full-width / curly variants (written as
    # escapes so they survive re-encoding of this file).
    _IGNORED_PUNCTUATION = frozenset(
        ",!?;:\"'()。、《》【】"
        "\uff0c\uff01\uff1f\uff1b\uff1a\uff08\uff09"
        "\u201c\u201d\u2018\u2019"
    )

    # Prompt template for the chapter analysis; filled in with str.format().
    ANALYSIS_PROMPT = """你是一位专业的小说编辑和剧情分析师。请深度分析以下章节内容:
**章节信息:**
- 章节: 第{chapter_number}
- 标题: {title}
- 字数: {word_count}
**章节内容:**
{content}
---
**分析任务:**
请从专业编辑的角度,全面分析这一章节:
### 1. 剧情钩子 (Hooks) - 吸引读者的元素
识别能够吸引读者继续阅读的关键元素:
- **悬念钩子**: 未解之谜、疑问、谜团
- **情感钩子**: 引发共鸣的情感点、触动心弦的时刻
- **冲突钩子**: 矛盾对抗、紧张局势
- **认知钩子**: 颠覆认知的信息、惊人真相
每个钩子需要:
- 类型分类
- 具体内容描述
- 强度评分(1-10)
- 出现位置(开头/中段/结尾)
- **关键词**: 【必填】从章节原文中逐字复制一段关键文本(8-25字),必须是原文中真实存在的连续文字,用于在文本中精确定位。不要概括或改写,必须原样复制!
### 2. 伏笔分析 (Foreshadowing)
- **埋下的新伏笔**: 描述内容、预期作用、隐藏程度(1-10)
- **回收的旧伏笔**: 呼应哪一章、回收效果评分
- **伏笔质量**: 巧妙性和合理性评估
- **关键词**: 【必填】从章节原文中逐字复制一段关键文本(8-25字),必须是原文中真实存在的连续文字,用于在文本中精确定位。不要概括或改写,必须原样复制!
### 3. 冲突分析 (Conflict)
- 冲突类型: 人与人/人与己/人与环境/人与社会
- 冲突各方及其立场
- 冲突强度评分(1-10)
- 冲突解决进度(0-100%)
### 4. 情感曲线 (Emotional Arc)
- 主导情绪: 紧张/温馨/悲伤/激昂/平静等
- 情感强度(1-10)
- 情绪变化轨迹描述
### 5. 角色状态追踪 (Character Development)
对每个出场角色分析:
- 心理状态变化(前→后)
- 关系变化
- 关键行动和决策
- 成长或退步
### 6. 关键情节点 (Plot Points)
列出3-5个核心情节点:
- 情节内容
- 类型(revelation/conflict/resolution/transition)
- 重要性(0.0-1.0)
- 对故事的影响
- **关键词**: 【必填】从章节原文中逐字复制一段关键文本(8-25字),必须是原文中真实存在的连续文字,用于在文本中精确定位。不要概括或改写,必须原样复制!
### 7. 场景与节奏
- 主要场景
- 叙事节奏(快/中/慢)
- 对话与描写的比例
### 8. 质量评分
- 节奏把控: 1-10分
- 吸引力: 1-10分
- 连贯性: 1-10分
- 整体质量: 1-10分
### 9. 改进建议
提供3-5条具体的改进建议
---
**输出格式(纯JSON,不要markdown标记):**
{{
"hooks": [
{{
"type": "悬念",
"content": "具体描述",
"strength": 8,
"position": "中段",
"keyword": "必须从原文逐字复制的文本片段"
}}
],
"foreshadows": [
{{
"content": "伏笔内容",
"type": "planted",
"strength": 7,
"subtlety": 8,
"reference_chapter": null,
"keyword": "必须从原文逐字复制的文本片段"
}}
],
"conflict": {{
"types": ["人与人", "人与己"],
"parties": ["主角-复仇", "反派-维护现状"],
"level": 8,
"description": "冲突描述",
"resolution_progress": 0.3
}},
"emotional_arc": {{
"primary_emotion": "紧张",
"intensity": 8,
"curve": "平静→紧张→高潮→释放",
"secondary_emotions": ["期待", "焦虑"]
}},
"character_states": [
{{
"character_name": "张三",
"state_before": "犹豫",
"state_after": "坚定",
"psychological_change": "心理变化描述",
"key_event": "触发事件",
"relationship_changes": {{"李四": "关系改善"}}
}}
],
"plot_points": [
{{
"content": "情节点描述",
"type": "revelation",
"importance": 0.9,
"impact": "推动故事发展",
"keyword": "必须从原文逐字复制的文本片段"
}}
],
"scenes": [
{{
"location": "地点",
"atmosphere": "氛围",
"duration": "时长估计"
}}
],
"pacing": "varied",
"dialogue_ratio": 0.4,
"description_ratio": 0.3,
"scores": {{
"pacing": 8,
"engagement": 9,
"coherence": 8,
"overall": 8.5
}},
"plot_stage": "发展",
"suggestions": [
"具体建议1",
"具体建议2"
]
}}
**重要提示:**
1. 每个钩子、伏笔、情节点的keyword字段是必填的,不能为空
2. keyword必须是从章节原文中逐字复制的文本,长度8-25字
3. keyword用于在前端标注文本位置,所以必须能在原文中精确找到
4. 不要使用概括性语句或改写后的文字作为keyword
只返回JSON,不要其他说明。"""

    def __init__(self, ai_service: AIService):
        """
        Create an analyzer bound to one AI service.

        Args:
            ai_service: Service used to generate the analysis text
        """
        self.ai_service = ai_service
        logger.info("✅ PlotAnalyzer初始化成功")

    async def analyze_chapter(
        self,
        chapter_number: int,
        title: str,
        content: str,
        word_count: int
    ) -> Optional[Dict[str, Any]]:
        """
        Analyze one chapter with the AI service.

        Args:
            chapter_number: Chapter number
            title: Chapter title
            content: Chapter text; only the first 8000 characters are sent
            word_count: Word count reported to the model

        Returns:
            The parsed analysis dict, or None when the call or parsing failed.
        """
        try:
            logger.info(f"🔍 开始分析第{chapter_number}章: {title}")

            # Slicing is a no-op for shorter chapters, so no length check is needed.
            analysis_content = content[:self._MAX_ANALYSIS_CHARS]

            prompt = self.ANALYSIS_PROMPT.format(
                chapter_number=chapter_number,
                title=title,
                word_count=word_count,
                content=analysis_content
            )

            # max_tokens is deliberately not passed: the value configured in
            # the user's settings applies.
            logger.info(f" 调用AI分析(内容长度: {len(analysis_content)}字)...")
            response = await self.ai_service.generate_text(
                prompt=prompt,
                temperature=0.3  # low temperature for more stable JSON output
            )

            analysis_result = self._parse_analysis_response(response)
            if not analysis_result:
                logger.error(f"❌ 第{chapter_number}章分析失败: JSON解析错误")
                return None

            logger.info(f"✅ 第{chapter_number}章分析完成")
            logger.info(f" - 钩子: {len(analysis_result.get('hooks', []))}")
            logger.info(f" - 伏笔: {len(analysis_result.get('foreshadows', []))}")
            logger.info(f" - 情节点: {len(analysis_result.get('plot_points', []))}")
            logger.info(f" - 整体评分: {analysis_result.get('scores', {}).get('overall', 'N/A')}")
            return analysis_result

        except Exception as e:
            # Service boundary: analysis is best-effort and must not raise.
            logger.error(f"❌ 章节分析异常: {str(e)}")
            return None

    def _parse_analysis_response(self, response: str) -> Optional[Dict[str, Any]]:
        """
        Parse the text returned by the model into an analysis dict.

        Markdown code fences are stripped first. If the remaining text is not
        valid JSON, the outermost ``{...}`` span of the raw response is tried.
        Whichever path succeeds, missing required keys (``hooks``,
        ``plot_points``, ``scores``) are filled with empty defaults.

        Args:
            response: Raw text returned by the AI service

        Returns:
            The parsed dict, or None when no JSON object could be extracted.
        """
        try:
            cleaned = response.strip()
            # Remove optional markdown fences around the payload.
            cleaned = re.sub(r'^```json\s*', '', cleaned)
            cleaned = re.sub(r'^```\s*', '', cleaned)
            cleaned = re.sub(r'\s*```$', '', cleaned)

            try:
                result = json.loads(cleaned)
            except json.JSONDecodeError as e:
                logger.error(f"❌ JSON解析失败: {str(e)}")
                logger.error(f" 原始响应(前500字): {response[:500]}")

                # Fallback: the model may have wrapped the JSON in extra prose.
                json_match = re.search(r'\{[\s\S]*\}', response)
                if not json_match:
                    return None
                try:
                    result = json.loads(json_match.group())
                except json.JSONDecodeError:
                    return None
                logger.info("✅ 通过正则提取成功解析JSON")

            if not isinstance(result, dict):
                logger.error("❌ 分析结果不是JSON对象")
                return None

            # Applied on both parse paths so callers always see these keys.
            for field, default_factory in self._REQUIRED_FIELD_DEFAULTS:
                if field not in result:
                    logger.warning(f"⚠️ 分析结果缺少字段: {field}")
                    result[field] = default_factory()

            return result

        except Exception as e:
            logger.error(f"❌ 解析异常: {str(e)}")
            return None

    @staticmethod
    def _as_number(value: Any, default: float) -> float:
        """Coerce a model-supplied value to float, using default when it is not numeric."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _as_str_list(value: Any) -> List[str]:
        """Normalize a model-supplied value to a list of strings (a scalar becomes one item)."""
        if isinstance(value, (list, tuple)):
            return [str(item) for item in value]
        return [str(value)] if value else []

    @staticmethod
    def _dict_items(analysis: Dict[str, Any], key: str) -> List[Dict[str, Any]]:
        """Return the dict entries stored under key, ignoring anything malformed."""
        items = analysis.get(key)
        if not isinstance(items, list):
            return []
        return [item for item in items if isinstance(item, dict)]

    def _locate_keyword(self, chapter_content: str, raw_keyword: Any, label: str) -> tuple:
        """Find a keyword in the chapter text and log it; returns (keyword, position, length)."""
        keyword = raw_keyword if isinstance(raw_keyword, str) else ''
        position, length = self._find_text_position(chapter_content, keyword)
        logger.info(f" {label}位置: keyword='{keyword[:30]}...', pos={position}, len={length}")
        return keyword, position, length

    def extract_memories_from_analysis(
        self,
        analysis: Dict[str, Any],
        chapter_id: str,
        chapter_number: int,
        chapter_content: str = ""
    ) -> List[Dict[str, Any]]:
        """
        Convert an analysis result into memory records.

        Records are produced for hooks with strength >= 6, every foreshadow,
        plot points with importance >= 0.6, every character state change, and
        the conflict when its level is >= 7. Malformed entries in the model
        output are skipped or defaulted so that one bad item does not discard
        the rest.

        Args:
            analysis: Parsed analysis result
            chapter_id: Chapter ID
            chapter_number: Chapter number
            chapter_content: Full chapter text, used to locate keywords

        Returns:
            List of memory records; empty when extraction failed.
        """
        memories = []

        try:
            # 1. Hooks (only the strong ones are kept)
            for hook in self._dict_items(analysis, 'hooks'):
                strength = self._as_number(hook.get('strength'), 0)
                if strength < 6:
                    continue
                keyword, position, length = self._locate_keyword(
                    chapter_content, hook.get('keyword'), "钩子"
                )
                memories.append({
                    'type': 'hook',
                    'content': f"[{hook.get('type', '未知')}钩子] {hook.get('content', '')}",
                    'title': f"{hook.get('type', '钩子')} - {hook.get('position', '')}",
                    'metadata': {
                        'chapter_id': chapter_id,
                        'chapter_number': chapter_number,
                        'importance_score': min(strength / 10, 1.0),
                        'tags': [hook.get('type', '钩子'), hook.get('position', '')],
                        'is_foreshadow': 0,
                        'keyword': keyword,
                        'text_position': position,
                        'text_length': length,
                        'strength': hook.get('strength', 5),
                        'position_desc': hook.get('position', '')
                    }
                })

            # 2. Foreshadows (is_foreshadow: 1 = planted, 2 = resolved)
            for foreshadow in self._dict_items(analysis, 'foreshadows'):
                is_planted = foreshadow.get('type') == 'planted'
                strength = self._as_number(foreshadow.get('strength'), 5)
                keyword, position, length = self._locate_keyword(
                    chapter_content, foreshadow.get('keyword'), "伏笔"
                )
                memories.append({
                    'type': 'foreshadow',
                    'content': foreshadow.get('content', ''),
                    'title': '埋下伏笔' if is_planted else '回收伏笔',
                    'metadata': {
                        'chapter_id': chapter_id,
                        'chapter_number': chapter_number,
                        'importance_score': min(strength / 10, 1.0),
                        'tags': ['伏笔', foreshadow.get('type', 'planted')],
                        'is_foreshadow': 1 if is_planted else 2,
                        'reference_chapter': foreshadow.get('reference_chapter'),
                        'keyword': keyword,
                        'text_position': position,
                        'text_length': length,
                        'foreshadow_type': foreshadow.get('type', 'planted'),
                        'strength': foreshadow.get('strength', 5)
                    }
                })

            # 3. Key plot points (only the important ones are kept)
            for plot_point in self._dict_items(analysis, 'plot_points'):
                importance = self._as_number(plot_point.get('importance'), 0)
                if importance < 0.6:
                    continue
                keyword, position, length = self._locate_keyword(
                    chapter_content, plot_point.get('keyword'), "情节点"
                )
                memories.append({
                    'type': 'plot_point',
                    'content': f"{plot_point.get('content', '')}。影响: {plot_point.get('impact', '')}",
                    'title': f"情节点 - {plot_point.get('type', '未知')}",
                    'metadata': {
                        'chapter_id': chapter_id,
                        'chapter_number': chapter_number,
                        'importance_score': importance,
                        'tags': ['情节点', plot_point.get('type', '未知')],
                        'is_foreshadow': 0,
                        'keyword': keyword,
                        'text_position': position,
                        'text_length': length
                    }
                })

            # 4. Character state changes
            for char_state in self._dict_items(analysis, 'character_states'):
                char_name = str(char_state.get('character_name') or '未知角色')
                before = char_state.get('state_before', '')
                after = char_state.get('state_after', '')
                change = char_state.get('psychological_change', '')
                memories.append({
                    'type': 'character_event',
                    # Separators keep the before/after states and the
                    # description from running together.
                    'content': f"{char_name}的状态变化: {before} → {after}。{change}",
                    'title': f"{char_name}的变化",
                    'metadata': {
                        'chapter_id': chapter_id,
                        'chapter_number': chapter_number,
                        'importance_score': 0.7,
                        'tags': ['角色', char_name, '状态变化'],
                        'related_characters': [char_name],
                        'is_foreshadow': 0
                    }
                })

            # 5. A high-intensity conflict is recorded as a plot point too
            conflict = analysis.get('conflict')
            if isinstance(conflict, dict) and conflict:
                level = self._as_number(conflict.get('level'), 0)
                if level >= 7:
                    # The model may return a bare string instead of a list.
                    parties = self._as_str_list(conflict.get('parties'))
                    types = self._as_str_list(conflict.get('types'))
                    memories.append({
                        'type': 'plot_point',
                        'content': f"重要冲突: {conflict.get('description', '')}。冲突各方: {', '.join(parties)}",
                        'title': f"冲突 - 强度{conflict.get('level', 0)}",
                        'metadata': {
                            'chapter_id': chapter_id,
                            'chapter_number': chapter_number,
                            'importance_score': min(level / 10, 1.0),
                            'tags': ['冲突'] + types,
                            'is_foreshadow': 0
                        }
                    })

            logger.info(f"📝 从分析中提取了{len(memories)}条记忆")
            return memories

        except Exception as e:
            logger.error(f"❌ 提取记忆失败: {str(e)}")
            return []

    def _find_text_position(self, full_text: str, keyword: str) -> tuple[int, int]:
        """
        Locate a keyword inside the chapter text.

        Three strategies are tried in order: an exact match, a match that
        ignores punctuation on both sides, and a match on the first 15
        characters of the keyword. The returned offset and length always refer
        to ``full_text`` itself, so ``full_text[pos:pos + length]`` is the
        matched span (including any punctuation skipped by the second
        strategy).

        Args:
            full_text: Full chapter text
            keyword: Text to locate

        Returns:
            (start offset, length); (-1, 0) when nothing matched.
        """
        if not isinstance(full_text, str) or not isinstance(keyword, str):
            return (-1, 0)
        if not keyword or not full_text:
            return (-1, 0)

        # 1. Exact match
        pos = full_text.find(keyword)
        if pos != -1:
            return (pos, len(keyword))

        # 2. Punctuation-insensitive match. `kept` records, for every character
        # of the stripped text, its index in the original text so the hit can
        # be mapped back to real offsets.
        ignored = self._IGNORED_PUNCTUATION
        clean_keyword = "".join(ch for ch in keyword if ch not in ignored)
        if clean_keyword:  # an all-punctuation keyword must not match at offset 0
            kept = [idx for idx, ch in enumerate(full_text) if ch not in ignored]
            clean_text = "".join(full_text[idx] for idx in kept)
            pos = clean_text.find(clean_keyword)
            if pos != -1:
                start = kept[pos]
                end = kept[pos + len(clean_keyword) - 1]
                return (start, end - start + 1)

        # 3. Prefix match. Only meaningful when the keyword is longer than the
        # prefix; otherwise it would repeat the failed exact match.
        if len(keyword) > 15:
            partial = keyword[:15]
            pos = full_text.find(partial)
            if pos != -1:
                return (pos, len(partial))

        # 4. Not found
        logger.debug(f"未找到关键词位置: {keyword[:30]}...")
        return (-1, 0)

    def generate_analysis_summary(self, analysis: Dict[str, Any]) -> str:
        """
        Render a short human-readable report from an analysis result.

        Args:
            analysis: Parsed analysis result

        Returns:
            The report text, or a fixed failure message when rendering raised.
        """
        try:
            lines = ["=== 章节分析报告 ===\n"]

            # Scores
            scores = analysis.get('scores') or {}
            lines.append("【整体评分】")
            lines.append(f" 整体质量: {scores.get('overall', 'N/A')}/10")
            lines.append(f" 节奏把控: {scores.get('pacing', 'N/A')}/10")
            lines.append(f" 吸引力: {scores.get('engagement', 'N/A')}/10")
            lines.append(f" 连贯性: {scores.get('coherence', 'N/A')}/10\n")

            # Plot stage
            lines.append(f"【剧情阶段】{analysis.get('plot_stage', '未知')}\n")

            # Hooks (first three only)
            hooks = analysis.get('hooks') or []
            if hooks:
                lines.append(f"【钩子分析】共{len(hooks)}个")
                for hook in hooks[:3]:
                    snippet = (hook.get('content') or '')[:50]
                    lines.append(f" • [{hook.get('type')}] {snippet}... (强度:{hook.get('strength', 0)})")
                lines.append("")

            # Foreshadows
            foreshadows = analysis.get('foreshadows') or []
            if foreshadows:
                planted = sum(1 for f in foreshadows if f.get('type') == 'planted')
                resolved = sum(1 for f in foreshadows if f.get('type') == 'resolved')
                lines.append(f"【伏笔分析】埋下{planted}个, 回收{resolved}个\n")

            # Conflict
            conflict = analysis.get('conflict') or {}
            if conflict:
                progress = conflict.get('resolution_progress') or 0
                lines.append("【冲突分析】")
                lines.append(f" 类型: {', '.join(map(str, conflict.get('types') or []))}")
                lines.append(f" 强度: {conflict.get('level', 0)}/10")
                # round() avoids float truncation (0.29 * 100 is 28.99...).
                lines.append(f" 进度: {round(progress * 100)}%\n")

            # Suggestions
            suggestions = analysis.get('suggestions') or []
            if suggestions:
                lines.append("【改进建议】")
                for number, suggestion in enumerate(suggestions, 1):
                    lines.append(f" {number}. {suggestion}")

            return "\n".join(lines)

        except Exception as e:
            logger.error(f"❌ 生成摘要失败: {str(e)}")
            return "分析摘要生成失败"
# Lazily created module-level analyzer; built on the first get_plot_analyzer() call.
_plot_analyzer_instance = None


def get_plot_analyzer(ai_service: AIService) -> PlotAnalyzer:
    """
    Return a PlotAnalyzer bound to the given AI service.

    The most recently created analyzer is cached and reused as long as the
    same service object is passed. A different service object gets a fresh
    analyzer, so a caller never receives an analyzer bound to a service that
    was supplied by an earlier caller.

    Args:
        ai_service: AI service the analyzer should use

    Returns:
        A PlotAnalyzer whose ``ai_service`` is the given object.
    """
    global _plot_analyzer_instance
    cached = _plot_analyzer_instance
    if cached is None or cached.ai_service is not ai_service:
        cached = PlotAnalyzer(ai_service)
        _plot_analyzer_instance = cached
    return cached
+81 -15
View File
@@ -315,7 +315,7 @@ class PromptService:
2. 数组中要包含{chapter_count}个章节对象
3. 文本中不要使用中文引号(""),改用【】或《》"""
# 大纲续写提示词
# 大纲续写提示词(记忆增强版)
OUTLINE_CONTINUE_GENERATION = """你是一位经验丰富的小说作家和编剧。请基于以下信息续写小说大纲:
【项目信息】
@@ -340,6 +340,11 @@ class PromptService:
【最近剧情】
{recent_plot}
【🧠 智能记忆系统 - 续写参考】
以下是从故事记忆库中检索到的相关信息,请在续写大纲时参考:
{memory_context}
【续写指导】
- 当前情节阶段:{plot_stage_instruction}
- 起始章节编号:第{start_chapter}
@@ -348,10 +353,12 @@ class PromptService:
请生成第{start_chapter}章到第{end_chapter}章的大纲。
要求:
- 与前文自然衔接,保持故事连贯性
- 遵循情节阶段的发展要求
- 保持与已有章节相同的风格和详细程度
- 推进角色成长和情节发展
- **剧情连贯性**与前文自然衔接,保持故事连贯性
- **记忆参考**:适当参考记忆系统中的伏笔、钩子和情节点
- **伏笔回收**:可以考虑回收未完结的伏笔,制造呼应
- **角色发展**:遵循角色在前文中的成长轨迹
- **情节阶段**:遵循情节阶段的发展要求
- **风格一致**:保持与已有章节相同的风格和详细程度
**重要格式要求:**
1. 只返回纯JSON数组格式,不要包含任何markdown标记、代码块标记或其他说明文字
@@ -465,7 +472,7 @@ class PromptService:
请直接输出章节正文内容,不要包含章节标题和其他说明文字。"""
# 章节完整创作提示词(带前置章节上下文)
# 章节完整创作提示词(带前置章节上下文和记忆增强
CHAPTER_GENERATION_WITH_CONTEXT = """你是一位专业的小说作家。请根据以下信息创作本章内容:
项目信息:
@@ -489,6 +496,11 @@ class PromptService:
【已完成的前置章节内容】
{previous_content}
【🧠 智能记忆系统 - 重要参考】
以下是从故事记忆库中检索到的相关信息,请在创作时适当参考和呼应:
{memory_context}
本章信息:
- 章节序号:第{chapter_number}
- 章节标题:{chapter_title}
@@ -518,8 +530,15 @@ class PromptService:
- 体现世界观特色
5. **承上启下**
- 开头自然衔接上一章结尾
- 结尾为下一章做好铺垫
- 开头自然衔接上一章结尾
- 结尾为下一章做好铺垫
6. **记忆系统使用指南**
- **最近章节记忆**:保持情节连贯,注意角色状态和剧情发展
- **语义相关记忆**:参考相似情节的处理方式
- **未完结伏笔**:适当时机可以回收伏笔,制造呼应效果
- **角色状态记忆**:确保角色行为符合其发展轨迹
- **重要情节点**:与关键剧情保持一致
请直接输出章节正文内容,不要包含章节标题和其他说明文字。"""
@@ -746,14 +765,26 @@ class PromptService:
characters_info: str, outlines_context: str,
chapter_number: int, chapter_title: str,
chapter_outline: str, style_content: str = "",
target_word_count: int = 3000) -> str:
target_word_count: int = 3000,
memory_context: dict = None) -> str:
"""
获取章节完整创作提示词
Args:
style_content: 写作风格要求内容,如果提供则会追加到提示词中
target_word_count: 目标字数,默认3000字
memory_context: 记忆上下文(可选)
"""
# 格式化记忆上下文
memory_text = ""
if memory_context:
memory_text = "\n【🧠 智能记忆系统 - 重要参考】\n"
memory_text += memory_context.get('recent_context', '')
memory_text += "\n" + memory_context.get('relevant_memories', '')
memory_text += "\n" + memory_context.get('foreshadows', '')
memory_text += "\n" + memory_context.get('character_states', '')
memory_text += "\n" + memory_context.get('plot_points', '')
base_prompt = cls.format_prompt(
cls.CHAPTER_GENERATION,
title=title,
@@ -772,6 +803,13 @@ class PromptService:
target_word_count=target_word_count
)
# 插入记忆上下文
if memory_text:
base_prompt = base_prompt.replace(
"本章信息:",
memory_text + "\n\n本章信息:"
)
# 如果有风格要求,应用到提示词中
if style_content:
return WritingStyleManager.apply_style_to_prompt(base_prompt, style_content)
@@ -786,14 +824,27 @@ class PromptService:
previous_content: str, chapter_number: int,
chapter_title: str, chapter_outline: str,
style_content: str = "",
target_word_count: int = 3000) -> str:
target_word_count: int = 3000,
memory_context: dict = None) -> str:
"""
获取章节完整创作提示词(带前置章节上下文)
获取章节完整创作提示词(带前置章节上下文和记忆增强
Args:
style_content: 写作风格要求内容,如果提供则会追加到提示词中
target_word_count: 目标字数,默认3000字
memory_context: 记忆上下文(可选)
"""
# 格式化记忆上下文
memory_text = ""
if memory_context:
memory_text = memory_context.get('recent_context', '')
memory_text += "\n" + memory_context.get('relevant_memories', '')
memory_text += "\n" + memory_context.get('foreshadows', '')
memory_text += "\n" + memory_context.get('character_states', '')
memory_text += "\n" + memory_context.get('plot_points', '')
else:
memory_text = "暂无相关记忆"
base_prompt = cls.format_prompt(
cls.CHAPTER_GENERATION_WITH_CONTEXT,
title=title,
@@ -810,7 +861,8 @@ class PromptService:
chapter_number=chapter_number,
chapter_title=chapter_title,
chapter_outline=chapter_outline,
target_word_count=target_word_count
target_word_count=target_word_count,
memory_context=memory_text
)
# 如果有风格要求,应用到提示词中
@@ -839,9 +891,22 @@ class PromptService:
current_chapter_count: int, all_chapters_brief: str,
recent_plot: str, plot_stage_instruction: str,
start_chapter: int, story_direction: str,
requirements: str = "") -> str:
"""获取大纲续写提示词"""
requirements: str = "",
memory_context: dict = None) -> str:
"""获取大纲续写提示词(支持记忆增强)"""
end_chapter = start_chapter + chapter_count - 1
# 格式化记忆上下文
memory_text = ""
if memory_context:
memory_text = memory_context.get('recent_context', '')
memory_text += "\n" + memory_context.get('relevant_memories', '')
memory_text += "\n" + memory_context.get('foreshadows', '')
memory_text += "\n" + memory_context.get('character_states', '')
memory_text += "\n" + memory_context.get('plot_points', '')
else:
memory_text = "暂无相关记忆(可能是首次续写或记忆库为空)"
return cls.format_prompt(
cls.OUTLINE_CONTINUE_GENERATION,
title=title,
@@ -861,7 +926,8 @@ class PromptService:
start_chapter=start_chapter,
end_chapter=end_chapter,
story_direction=story_direction,
requirements=requirements or "无特殊要求"
requirements=requirements or "无特殊要求",
memory_context=memory_text
)
@classmethod