refactor: 导入导出功能增强:版本升级至1.1.0,新增职业系统、故事记忆、剧情分析的导出选项

This commit is contained in:
xiamuceer-j
2026-01-14 19:47:28 +08:00
parent 7ba2b2e5fa
commit fb16cc4072
3 changed files with 707 additions and 34 deletions
+540 -13
View File
@@ -3,7 +3,7 @@ import json
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy import select, or_
from app.models.project import Project
from app.models.chapter import Chapter
from app.models.character import Character
@@ -11,6 +11,9 @@ from app.models.outline import Outline
from app.models.relationship import CharacterRelationship, Organization, OrganizationMember
from app.models.writing_style import WritingStyle
from app.models.generation_history import GenerationHistory
from app.models.career import Career, CharacterCareer
from app.models.memory import StoryMemory, PlotAnalysis
from app.models.project_default_style import ProjectDefaultStyle
from app.schemas.import_export import (
ProjectExportData,
ChapterExportData,
@@ -21,6 +24,11 @@ from app.schemas.import_export import (
OrganizationMemberExportData,
WritingStyleExportData,
GenerationHistoryExportData,
CareerExportData,
CharacterCareerExportData,
StoryMemoryExportData,
PlotAnalysisExportData,
ProjectDefaultStyleExportData,
ImportValidationResult,
ImportResult
)
@@ -32,14 +40,18 @@ logger = get_logger(__name__)
class ImportExportService:
"""导入导出服务类"""
SUPPORTED_VERSION = "1.0.0"
SUPPORTED_VERSIONS = ["1.0.0", "1.1.0"] # 支持的版本列表
CURRENT_VERSION = "1.1.0" # 当前导出版本
@staticmethod
async def export_project(
project_id: str,
db: AsyncSession,
include_generation_history: bool = False,
include_writing_styles: bool = True
include_writing_styles: bool = True,
include_careers: bool = True,
include_memories: bool = False,
include_plot_analysis: bool = False
) -> ProjectExportData:
"""
导出项目完整数据
@@ -49,6 +61,9 @@ class ImportExportService:
db: 数据库会话
include_generation_history: 是否包含生成历史
include_writing_styles: 是否包含写作风格
include_careers: 是否包含职业系统
include_memories: 是否包含故事记忆
include_plot_analysis: 是否包含剧情分析
Returns:
ProjectExportData: 导出的项目数据
@@ -77,7 +92,7 @@ class ImportExportService:
"chapter_count": project.chapter_count,
"narrative_perspective": project.narrative_perspective,
"character_count": project.character_count,
"outline_mode": project.outline_mode,
"outline_mode": project.outline_mode,
"user_id": project.user_id,
"created_at": project.created_at.isoformat() if project.created_at else None,
}
@@ -118,8 +133,34 @@ class ImportExportService:
generation_history = await ImportExportService._export_generation_history(project_id, db)
logger.info(f"导出生成历史数: {len(generation_history)}")
# 导出职业系统(可选)
careers = []
character_careers = []
if include_careers:
careers = await ImportExportService._export_careers(project_id, db)
logger.info(f"导出职业数: {len(careers)}")
character_careers = await ImportExportService._export_character_careers(project_id, db)
logger.info(f"导出角色职业关联数: {len(character_careers)}")
# 导出故事记忆(可选)
story_memories = []
if include_memories:
story_memories = await ImportExportService._export_story_memories(project_id, db)
logger.info(f"导出故事记忆数: {len(story_memories)}")
# 导出剧情分析(可选)
plot_analysis = []
if include_plot_analysis:
plot_analysis = await ImportExportService._export_plot_analysis(project_id, db)
logger.info(f"导出剧情分析数: {len(plot_analysis)}")
# 导出项目默认风格
project_default_style = await ImportExportService._export_project_default_style(project_id, db)
if project_default_style:
logger.info(f"导出项目默认风格: {project_default_style.style_name}")
export_data = ProjectExportData(
version=ImportExportService.SUPPORTED_VERSION,
version=ImportExportService.CURRENT_VERSION,
export_time=datetime.utcnow().isoformat(),
project=project_data,
chapters=chapters,
@@ -129,7 +170,12 @@ class ImportExportService:
organizations=organizations,
organization_members=org_members,
writing_styles=writing_styles,
generation_history=generation_history
generation_history=generation_history,
careers=careers,
character_careers=character_careers,
story_memories=story_memories,
plot_analysis=plot_analysis,
project_default_style=project_default_style
)
logger.info(f"项目导出完成: {project_id}")
@@ -394,6 +440,185 @@ class ImportExportService:
for history, chapter in histories
]
@staticmethod
async def _export_careers(project_id: str, db: AsyncSession) -> List[CareerExportData]:
"""导出职业系统"""
result = await db.execute(
select(Career)
.where(Career.project_id == project_id)
.order_by(Career.type, Career.created_at)
)
careers = result.scalars().all()
return [
CareerExportData(
name=career.name,
type=career.type,
description=career.description,
category=career.category,
stages=career.stages,
max_stage=career.max_stage or 10,
requirements=career.requirements,
special_abilities=career.special_abilities,
worldview_rules=career.worldview_rules,
attribute_bonuses=career.attribute_bonuses,
source=career.source or "ai",
created_at=career.created_at.isoformat() if career.created_at else None
)
for career in careers
]
@staticmethod
async def _export_character_careers(project_id: str, db: AsyncSession) -> List[CharacterCareerExportData]:
"""导出角色职业关联"""
# 查询所有属于该项目的角色职业关联
result = await db.execute(
select(CharacterCareer, Character, Career)
.join(Character, CharacterCareer.character_id == Character.id)
.join(Career, CharacterCareer.career_id == Career.id)
.where(Character.project_id == project_id)
)
character_careers = result.all()
return [
CharacterCareerExportData(
character_name=char.name,
career_name=career.name,
career_type=cc.career_type,
current_stage=cc.current_stage or 1,
stage_progress=cc.stage_progress or 0,
started_at=cc.started_at,
reached_current_stage_at=cc.reached_current_stage_at,
notes=cc.notes
)
for cc, char, career in character_careers
]
@staticmethod
async def _export_story_memories(project_id: str, db: AsyncSession) -> List[StoryMemoryExportData]:
"""导出故事记忆"""
# 构建章节ID到标题的映射
chapter_result = await db.execute(
select(Chapter).where(Chapter.project_id == project_id)
)
chapters = chapter_result.scalars().all()
chapter_mapping = {ch.id: ch.title for ch in chapters}
# 构建角色ID到名称的映射
char_result = await db.execute(
select(Character).where(Character.project_id == project_id)
)
characters = char_result.scalars().all()
char_mapping = {char.id: char.name for char in characters}
result = await db.execute(
select(StoryMemory)
.where(StoryMemory.project_id == project_id)
.order_by(StoryMemory.story_timeline, StoryMemory.chapter_position)
)
memories = result.scalars().all()
exported = []
for mem in memories:
# 将角色ID列表转换为名称列表
related_char_names = None
if mem.related_characters:
related_char_names = [
char_mapping.get(char_id, char_id)
for char_id in mem.related_characters
]
exported.append(StoryMemoryExportData(
chapter_title=chapter_mapping.get(mem.chapter_id) if mem.chapter_id else None,
memory_type=mem.memory_type,
title=mem.title,
content=mem.content,
full_context=mem.full_context,
related_characters=related_char_names,
related_locations=mem.related_locations,
tags=mem.tags,
importance_score=mem.importance_score or 0.5,
story_timeline=mem.story_timeline,
chapter_position=mem.chapter_position or 0,
text_length=mem.text_length or 0,
is_foreshadow=mem.is_foreshadow or 0,
foreshadow_strength=mem.foreshadow_strength,
created_at=mem.created_at.isoformat() if mem.created_at else None
))
return exported
@staticmethod
async def _export_plot_analysis(project_id: str, db: AsyncSession) -> List[PlotAnalysisExportData]:
"""导出剧情分析"""
# 构建章节ID到标题的映射
chapter_result = await db.execute(
select(Chapter).where(Chapter.project_id == project_id)
)
chapters = chapter_result.scalars().all()
chapter_mapping = {ch.id: ch.title for ch in chapters}
result = await db.execute(
select(PlotAnalysis)
.where(PlotAnalysis.project_id == project_id)
)
analyses = result.scalars().all()
exported = []
for analysis in analyses:
chapter_title = chapter_mapping.get(analysis.chapter_id)
if not chapter_title:
continue # 跳过没有关联章节的分析
exported.append(PlotAnalysisExportData(
chapter_title=chapter_title,
plot_stage=analysis.plot_stage,
conflict_level=analysis.conflict_level,
conflict_types=analysis.conflict_types,
emotional_tone=analysis.emotional_tone,
emotional_intensity=analysis.emotional_intensity,
emotional_curve=analysis.emotional_curve,
hooks=analysis.hooks,
hooks_count=analysis.hooks_count or 0,
hooks_avg_strength=analysis.hooks_avg_strength,
foreshadows=analysis.foreshadows,
foreshadows_planted=analysis.foreshadows_planted or 0,
foreshadows_resolved=analysis.foreshadows_resolved or 0,
plot_points=analysis.plot_points,
plot_points_count=analysis.plot_points_count or 0,
character_states=analysis.character_states,
scenes=analysis.scenes,
pacing=analysis.pacing,
overall_quality_score=analysis.overall_quality_score,
pacing_score=analysis.pacing_score,
engagement_score=analysis.engagement_score,
coherence_score=analysis.coherence_score,
analysis_report=analysis.analysis_report,
suggestions=analysis.suggestions,
word_count=analysis.word_count,
dialogue_ratio=analysis.dialogue_ratio,
description_ratio=analysis.description_ratio,
created_at=analysis.created_at.isoformat() if analysis.created_at else None
))
return exported
@staticmethod
async def _export_project_default_style(project_id: str, db: AsyncSession) -> Optional[ProjectDefaultStyleExportData]:
"""导出项目默认风格"""
result = await db.execute(
select(ProjectDefaultStyle, WritingStyle)
.join(WritingStyle, ProjectDefaultStyle.style_id == WritingStyle.id)
.where(ProjectDefaultStyle.project_id == project_id)
)
row = result.first()
if row:
_, style = row
return ProjectDefaultStyleExportData(style_name=style.name)
return None
@staticmethod
def validate_import_data(data: Dict) -> ImportValidationResult:
"""
@@ -413,8 +638,8 @@ class ImportExportService:
version = data.get("version", "")
if not version:
errors.append("缺少版本信息")
elif version != ImportExportService.SUPPORTED_VERSION:
warnings.append(f"版本不匹配: 导入文件版本为 {version}, 当前支持版本为 {ImportExportService.SUPPORTED_VERSION}")
elif version not in ImportExportService.SUPPORTED_VERSIONS:
warnings.append(f"版本不匹配: 导入文件版本为 {version}, 当前支持版本为 {', '.join(ImportExportService.SUPPORTED_VERSIONS)}")
# 检查必需字段
if "project" not in data:
@@ -424,7 +649,7 @@ class ImportExportService:
if not project.get("title"):
errors.append("项目标题不能为空")
# 统计数据
# 统计数据(包含新增字段)
statistics = {
"chapters": len(data.get("chapters", [])),
"characters": len(data.get("characters", [])),
@@ -433,7 +658,12 @@ class ImportExportService:
"organizations": len(data.get("organizations", [])),
"organization_members": len(data.get("organization_members", [])),
"writing_styles": len(data.get("writing_styles", [])),
"generation_history": len(data.get("generation_history", []))
"generation_history": len(data.get("generation_history", [])),
"careers": len(data.get("careers", [])),
"character_careers": len(data.get("character_careers", [])),
"story_memories": len(data.get("story_memories", [])),
"plot_analysis": len(data.get("plot_analysis", [])),
"has_default_style": data.get("project_default_style") is not None
}
# 检查数据完整性
@@ -565,6 +795,58 @@ class ImportExportService:
statistics["writing_styles"] = styles_count
logger.info(f"导入写作风格数: {styles_count}")
# 导入职业系统
career_mapping = await ImportExportService._import_careers(
new_project.id, data.get("careers", []), db
)
statistics["careers"] = len(career_mapping)
logger.info(f"导入职业数: {len(career_mapping)}")
# 导入角色职业关联
char_careers_count = await ImportExportService._import_character_careers(
data.get("character_careers", []), char_mapping, career_mapping, db
)
statistics["character_careers"] = char_careers_count
logger.info(f"导入角色职业关联数: {char_careers_count}")
# 导入故事记忆
# 需要先构建章节标题到ID的映射
chapter_title_to_id = {}
for ch_data in data.get("chapters", []):
title = ch_data.get("title")
if title:
# 查询刚导入的章节
ch_result = await db.execute(
select(Chapter).where(
Chapter.project_id == new_project.id,
Chapter.title == title
)
)
ch = ch_result.scalar_one_or_none()
if ch:
chapter_title_to_id[title] = ch.id
memories_count = await ImportExportService._import_story_memories(
new_project.id, data.get("story_memories", []), chapter_title_to_id, char_mapping, db
)
statistics["story_memories"] = memories_count
logger.info(f"导入故事记忆数: {memories_count}")
# 导入剧情分析
plot_analysis_count = await ImportExportService._import_plot_analysis(
new_project.id, data.get("plot_analysis", []), chapter_title_to_id, db
)
statistics["plot_analysis"] = plot_analysis_count
logger.info(f"导入剧情分析数: {plot_analysis_count}")
# 导入项目默认风格
default_style_imported = await ImportExportService._import_project_default_style(
new_project.id, data.get("project_default_style"), db
)
statistics["project_default_style"] = 1 if default_style_imported else 0
if default_style_imported:
logger.info("导入项目默认风格成功")
# 提交事务
await db.commit()
@@ -842,6 +1124,251 @@ class ImportExportService:
return count
@staticmethod
async def _import_careers(
project_id: str,
careers_data: List[Dict],
db: AsyncSession
) -> Dict[str, str]:
"""导入职业,返回名称到ID的映射"""
career_mapping = {}
for career_data in careers_data:
career = Career(
project_id=project_id,
name=career_data.get("name"),
type=career_data.get("type", "main"),
description=career_data.get("description"),
category=career_data.get("category"),
stages=career_data.get("stages", "[]"),
max_stage=career_data.get("max_stage", 10),
requirements=career_data.get("requirements"),
special_abilities=career_data.get("special_abilities"),
worldview_rules=career_data.get("worldview_rules"),
attribute_bonuses=career_data.get("attribute_bonuses"),
source=career_data.get("source", "ai")
)
db.add(career)
await db.flush()
career_mapping[career_data.get("name")] = career.id
return career_mapping
@staticmethod
async def _import_character_careers(
character_careers_data: List[Dict],
char_mapping: Dict[str, str],
career_mapping: Dict[str, str],
db: AsyncSession
) -> int:
"""导入角色职业关联"""
count = 0
for cc_data in character_careers_data:
char_name = cc_data.get("character_name")
career_name = cc_data.get("career_name")
char_id = char_mapping.get(char_name)
career_id = career_mapping.get(career_name)
if char_id and career_id:
# 检查是否已存在
existing = await db.execute(
select(CharacterCareer).where(
CharacterCareer.character_id == char_id,
CharacterCareer.career_id == career_id
)
)
if existing.scalar_one_or_none():
continue
char_career = CharacterCareer(
character_id=char_id,
career_id=career_id,
career_type=cc_data.get("career_type", "main"),
current_stage=cc_data.get("current_stage", 1),
stage_progress=cc_data.get("stage_progress", 0),
started_at=cc_data.get("started_at"),
reached_current_stage_at=cc_data.get("reached_current_stage_at"),
notes=cc_data.get("notes")
)
db.add(char_career)
count += 1
# 同时更新角色的主职业信息
if cc_data.get("career_type") == "main":
char_result = await db.execute(
select(Character).where(Character.id == char_id)
)
char = char_result.scalar_one_or_none()
if char:
char.main_career_id = career_id
char.main_career_stage = cc_data.get("current_stage", 1)
return count
@staticmethod
async def _import_story_memories(
project_id: str,
memories_data: List[Dict],
chapter_mapping: Dict[str, str],
char_mapping: Dict[str, str],
db: AsyncSession
) -> int:
"""导入故事记忆"""
count = 0
for mem_data in memories_data:
# 将章节标题转换为ID
chapter_id = None
chapter_title = mem_data.get("chapter_title")
if chapter_title and chapter_title in chapter_mapping:
chapter_id = chapter_mapping[chapter_title]
# 将角色名称列表转换为ID列表
related_char_ids = None
related_char_names = mem_data.get("related_characters")
if related_char_names:
related_char_ids = [
char_mapping.get(name)
for name in related_char_names
if char_mapping.get(name)
]
memory = StoryMemory(
project_id=project_id,
chapter_id=chapter_id,
memory_type=mem_data.get("memory_type"),
title=mem_data.get("title"),
content=mem_data.get("content"),
full_context=mem_data.get("full_context"),
related_characters=related_char_ids,
related_locations=mem_data.get("related_locations"),
tags=mem_data.get("tags"),
importance_score=mem_data.get("importance_score", 0.5),
story_timeline=mem_data.get("story_timeline", 0),
chapter_position=mem_data.get("chapter_position", 0),
text_length=mem_data.get("text_length", 0),
is_foreshadow=mem_data.get("is_foreshadow", 0),
foreshadow_strength=mem_data.get("foreshadow_strength")
)
db.add(memory)
count += 1
return count
@staticmethod
async def _import_plot_analysis(
project_id: str,
plot_data: List[Dict],
chapter_mapping: Dict[str, str],
db: AsyncSession
) -> int:
"""导入剧情分析"""
count = 0
for analysis_data in plot_data:
chapter_title = analysis_data.get("chapter_title")
chapter_id = chapter_mapping.get(chapter_title)
if not chapter_id:
continue # 跳过找不到章节的分析
# 检查是否已存在该章节的分析
existing = await db.execute(
select(PlotAnalysis).where(PlotAnalysis.chapter_id == chapter_id)
)
if existing.scalar_one_or_none():
continue
analysis = PlotAnalysis(
project_id=project_id,
chapter_id=chapter_id,
plot_stage=analysis_data.get("plot_stage"),
conflict_level=analysis_data.get("conflict_level"),
conflict_types=analysis_data.get("conflict_types"),
emotional_tone=analysis_data.get("emotional_tone"),
emotional_intensity=analysis_data.get("emotional_intensity"),
emotional_curve=analysis_data.get("emotional_curve"),
hooks=analysis_data.get("hooks"),
hooks_count=analysis_data.get("hooks_count", 0),
hooks_avg_strength=analysis_data.get("hooks_avg_strength"),
foreshadows=analysis_data.get("foreshadows"),
foreshadows_planted=analysis_data.get("foreshadows_planted", 0),
foreshadows_resolved=analysis_data.get("foreshadows_resolved", 0),
plot_points=analysis_data.get("plot_points"),
plot_points_count=analysis_data.get("plot_points_count", 0),
character_states=analysis_data.get("character_states"),
scenes=analysis_data.get("scenes"),
pacing=analysis_data.get("pacing"),
overall_quality_score=analysis_data.get("overall_quality_score"),
pacing_score=analysis_data.get("pacing_score"),
engagement_score=analysis_data.get("engagement_score"),
coherence_score=analysis_data.get("coherence_score"),
analysis_report=analysis_data.get("analysis_report"),
suggestions=analysis_data.get("suggestions"),
word_count=analysis_data.get("word_count"),
dialogue_ratio=analysis_data.get("dialogue_ratio"),
description_ratio=analysis_data.get("description_ratio")
)
db.add(analysis)
count += 1
return count
@staticmethod
async def _import_project_default_style(
project_id: str,
default_style_data: Optional[Dict],
db: AsyncSession
) -> bool:
"""导入项目默认风格"""
if not default_style_data:
return False
style_name = default_style_data.get("style_name")
if not style_name:
return False
# 获取项目所属用户
project_result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = project_result.scalar_one_or_none()
if not project:
return False
# 查找对应的风格(优先查找用户自定义风格,然后是全局预设风格)
# 先查用户自定义风格
style_result = await db.execute(
select(WritingStyle).where(
WritingStyle.user_id == project.user_id,
WritingStyle.name == style_name
)
)
style = style_result.scalar_one_or_none()
# 如果用户自定义风格不存在,查找全局预设风格
if not style:
style_result = await db.execute(
select(WritingStyle).where(
WritingStyle.user_id.is_(None),
WritingStyle.name == style_name
)
)
style = style_result.scalar_one_or_none()
if not style:
logger.warning(f"导入项目默认风格时未找到风格: {style_name}")
return False
# 创建项目默认风格关联
default_style = ProjectDefaultStyle(
project_id=project_id,
style_id=style.id
)
db.add(default_style)
logger.info(f"项目默认风格导入成功: {style_name}, style_id={style.id}")
return True
@staticmethod
async def export_characters(
character_ids: List[str],
@@ -919,7 +1446,7 @@ class ImportExportService:
exported_characters.append(char_data)
export_data = {
"version": ImportExportService.SUPPORTED_VERSION,
"version": ImportExportService.CURRENT_VERSION,
"export_time": datetime.utcnow().isoformat(),
"export_type": "characters",
"count": len(exported_characters),
@@ -1200,8 +1727,8 @@ class ImportExportService:
version = data.get("version", "")
if not version:
errors.append("缺少版本信息")
elif version != ImportExportService.SUPPORTED_VERSION:
warnings.append(f"版本不匹配: 导入文件版本为 {version}, 当前支持版本为 {ImportExportService.SUPPORTED_VERSION}")
elif version not in ImportExportService.SUPPORTED_VERSIONS:
warnings.append(f"版本不匹配: 导入文件版本为 {version}, 当前支持版本为 {', '.join(ImportExportService.SUPPORTED_VERSIONS)}")
# 检查导出类型
export_type = data.get("export_type", "")