feature:1.新增角色/组织卡片导入导出功能,支持批量

This commit is contained in:
xiamuceer
2025-12-29 16:48:02 +08:00
parent f2158cd36e
commit 3b97e88128
10 changed files with 1068 additions and 114 deletions
+1 -1
View File
@@ -8,7 +8,7 @@
# 应用配置
# ==========================================
APP_NAME=MuMuAINovel
APP_VERSION=1.2.2
APP_VERSION=1.2.3
APP_HOST=0.0.0.0
APP_PORT=8000
DEBUG=false
+164 -1
View File
@@ -1,5 +1,6 @@
"""角色管理API"""
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, File
from fastapi.responses import JSONResponse
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
import json
@@ -20,6 +21,8 @@ from app.schemas.character import (
)
from app.services.ai_service import AIService
from app.services.prompt_service import prompt_service, PromptService
from app.services.import_export_service import ImportExportService
from app.schemas.import_export import CharactersExportRequest, CharactersImportResult
from app.logger import get_logger
from app.api.settings import get_user_ai_service
@@ -1306,3 +1309,163 @@ async def generate_character_stream(
yield await SSEResponse.send_error(f"生成角色失败: {str(e)}")
return create_sse_response(generate())
@router.post("/export", summary="批量导出角色/组织")
async def export_characters(
export_request: CharactersExportRequest,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
批量导出角色/组织为JSON格式
- 支持单个或多个角色/组织导出
- 包含角色的所有信息(基础信息、职业、组织详情等)
- 返回JSON文件供下载
"""
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
if not export_request.character_ids:
raise HTTPException(status_code=400, detail="请至少选择一个角色/组织")
try:
# 验证所有角色的权限
for char_id in export_request.character_ids:
result = await db.execute(
select(Character).where(Character.id == char_id)
)
character = result.scalar_one_or_none()
if not character:
raise HTTPException(status_code=404, detail=f"角色不存在: {char_id}")
# 验证项目权限
await verify_project_access(character.project_id, user_id, db)
# 执行导出
export_data = await ImportExportService.export_characters(
character_ids=export_request.character_ids,
db=db
)
# 生成文件名
from datetime import datetime
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
count = len(export_request.character_ids)
filename = f"characters_export_{count}_{timestamp}.json"
logger.info(f"用户 {user_id} 导出了 {count} 个角色/组织")
# 返回JSON文件
return JSONResponse(
content=export_data,
headers={
"Content-Disposition": f"attachment; filename={filename}",
"Content-Type": "application/json; charset=utf-8"
}
)
except HTTPException:
raise
except Exception as e:
logger.error(f"导出角色/组织失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"导出失败: {str(e)}")
@router.post("/import", response_model=CharactersImportResult, summary="导入角色/组织")
async def import_characters(
project_id: str,
file: UploadFile = File(...),
request: Request = None,
db: AsyncSession = Depends(get_db)
):
"""
从JSON文件导入角色/组织
- 支持导入之前导出的角色/组织JSON文件
- 自动处理重复名称(跳过)
- 验证职业ID的有效性
- 自动创建组织详情记录
"""
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
# 验证项目权限
await verify_project_access(project_id, user_id, db)
# 验证文件类型
if not file.filename.endswith('.json'):
raise HTTPException(status_code=400, detail="只支持JSON格式文件")
try:
# 读取文件内容
content = await file.read()
data = json.loads(content.decode('utf-8'))
# 执行导入
result = await ImportExportService.import_characters(
data=data,
project_id=project_id,
user_id=user_id,
db=db
)
logger.info(f"用户 {user_id} 导入角色/组织到项目 {project_id}: {result['message']}")
return result
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"JSON格式错误: {str(e)}")
except Exception as e:
logger.error(f"导入角色/组织失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"导入失败: {str(e)}")
@router.post("/validate-import", summary="验证导入文件")
async def validate_import(
file: UploadFile = File(...),
request: Request = None
):
"""
验证角色/组织导入文件的格式和内容
- 检查文件格式
- 验证版本兼容性
- 统计数据量
- 返回验证结果和警告信息
"""
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
# 验证文件类型
if not file.filename.endswith('.json'):
raise HTTPException(status_code=400, detail="只支持JSON格式文件")
try:
# 读取文件内容
content = await file.read()
data = json.loads(content.decode('utf-8'))
# 验证数据
validation_result = ImportExportService.validate_characters_import(data)
logger.info(f"用户 {user_id} 验证导入文件: {file.filename}")
return validation_result
except json.JSONDecodeError as e:
return {
"valid": False,
"version": "",
"statistics": {"characters": 0, "organizations": 0},
"errors": [f"JSON格式错误: {str(e)}"],
"warnings": []
}
except Exception as e:
logger.error(f"验证导入文件失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"验证失败: {str(e)}")
+35
View File
@@ -36,9 +36,20 @@ class CharacterExportData(BaseModel):
personality: Optional[str] = None
background: Optional[str] = None
appearance: Optional[str] = None
relationships: Optional[str] = None
traits: Optional[List[str]] = None
organization_type: Optional[str] = None
organization_purpose: Optional[str] = None
organization_members: Optional[str] = None
avatar_url: Optional[str] = None
main_career_id: Optional[str] = None
main_career_stage: Optional[int] = None
sub_careers: Optional[str] = None
# 组织专属字段
power_level: Optional[int] = None
location: Optional[str] = None
motto: Optional[str] = None
color: Optional[str] = None
created_at: Optional[str] = None
@@ -138,4 +149,28 @@ class ImportResult(BaseModel):
project_id: Optional[str] = None
message: str
statistics: Dict[str, int] = {}
details: Optional[Dict[str, List[str]]] = None
warnings: List[str] = []
class CharactersExportRequest(BaseModel):
"""角色/组织批量导出请求"""
character_ids: List[str] = Field(..., description="要导出的角色/组织ID列表")
class CharactersExportData(BaseModel):
"""角色/组织批量导出数据"""
version: str = "1.0.0"
export_time: str
export_type: str = "characters"
count: int
data: List[CharacterExportData]
class CharactersImportResult(BaseModel):
"""角色/组织导入结果"""
success: bool
message: str
statistics: Dict[str, int]
details: Dict[str, List[str]]
warnings: List[str] = []
+401 -2
View File
@@ -1,7 +1,7 @@
"""导入导出服务"""
import json
from datetime import datetime
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.models.project import Project
@@ -840,4 +840,403 @@ class ImportExportService:
db.add(style)
count += 1
return count
return count
@staticmethod
async def export_characters(
character_ids: List[str],
db: AsyncSession
) -> Dict[str, Any]:
"""
导出角色/组织卡片
Args:
character_ids: 要导出的角色/组织ID列表
db: 数据库会话
Returns:
Dict: 导出的角色数据
"""
logger.info(f"开始导出角色/组织: {len(character_ids)}")
# 查询角色数据
result = await db.execute(
select(Character).where(Character.id.in_(character_ids))
)
characters = result.scalars().all()
if not characters:
raise ValueError("未找到指定的角色/组织")
# 导出角色数据
exported_characters = []
for char in characters:
# 解析 traits
traits = None
if char.traits:
try:
traits = json.loads(char.traits) if isinstance(char.traits, str) else char.traits
except:
traits = None
# 基础角色数据
char_data = {
"name": char.name,
"age": char.age,
"gender": char.gender,
"is_organization": char.is_organization or False,
"role_type": char.role_type,
"personality": char.personality,
"background": char.background,
"appearance": char.appearance,
"relationships": char.relationships,
"traits": traits,
"organization_type": char.organization_type,
"organization_purpose": char.organization_purpose,
"organization_members": char.organization_members,
"avatar_url": char.avatar_url,
"main_career_id": char.main_career_id,
"main_career_stage": char.main_career_stage,
"sub_careers": char.sub_careers,
"created_at": char.created_at.isoformat() if char.created_at else None
}
# 如果是组织,添加组织专属字段
if char.is_organization:
org_result = await db.execute(
select(Organization).where(Organization.character_id == char.id)
)
org = org_result.scalar_one_or_none()
if org:
char_data.update({
"power_level": org.power_level,
"location": org.location,
"motto": org.motto,
"color": org.color
})
exported_characters.append(char_data)
export_data = {
"version": ImportExportService.SUPPORTED_VERSION,
"export_time": datetime.utcnow().isoformat(),
"export_type": "characters",
"count": len(exported_characters),
"data": exported_characters
}
logger.info(f"角色/组织导出完成: {len(exported_characters)}")
return export_data
@staticmethod
async def import_characters(
data: Dict,
project_id: str,
user_id: str,
db: AsyncSession
) -> Dict[str, Any]:
"""
导入角色/组织卡片
Args:
data: 导入的JSON数据
project_id: 目标项目ID
user_id: 用户ID
db: 数据库会话
Returns:
Dict: 导入结果
"""
from app.models.career import CharacterCareer, Career
warnings = []
imported_characters = []
imported_organizations = []
skipped = []
errors = []
try:
# 验证数据格式
if "data" not in data:
raise ValueError("导入数据格式错误:缺少data字段")
characters_data = data["data"]
if not isinstance(characters_data, list):
raise ValueError("导入数据格式错误:data字段必须是数组")
# 验证项目权限
project_result = await db.execute(
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = project_result.scalar_one_or_none()
if not project:
raise ValueError("项目不存在或无权访问")
logger.info(f"开始导入 {len(characters_data)} 个角色/组织到项目 {project_id}")
# 处理每个角色/组织
for idx, char_data in enumerate(characters_data):
try:
name = char_data.get("name")
if not name:
errors.append(f"{idx+1}个角色缺少name字段")
continue
# 检查重复名称
existing_result = await db.execute(
select(Character).where(
Character.project_id == project_id,
Character.name == name
)
)
existing = existing_result.scalar_one_or_none()
if existing:
warnings.append(f"角色'{name}'已存在,已跳过")
skipped.append(name)
continue
# 处理traits
traits = char_data.get("traits")
if isinstance(traits, list):
traits = json.dumps(traits, ensure_ascii=False)
is_organization = char_data.get("is_organization", False)
# 创建角色
character = Character(
project_id=project_id,
name=name,
age=char_data.get("age"),
gender=char_data.get("gender"),
is_organization=is_organization,
role_type=char_data.get("role_type"),
personality=char_data.get("personality"),
background=char_data.get("background"),
appearance=char_data.get("appearance"),
relationships=char_data.get("relationships"),
traits=traits,
organization_type=char_data.get("organization_type"),
organization_purpose=char_data.get("organization_purpose"),
organization_members=char_data.get("organization_members"),
avatar_url=char_data.get("avatar_url"),
main_career_id=None, # 职业ID需要验证后再设置
main_career_stage=char_data.get("main_career_stage"),
sub_careers=None # 副职业需要验证后再设置
)
db.add(character)
await db.flush() # 获取character.id
# 处理主职业(如果有)
main_career_id = char_data.get("main_career_id")
main_career_stage = char_data.get("main_career_stage")
if main_career_id and not is_organization:
# 验证职业是否存在
career_result = await db.execute(
select(Career).where(
Career.id == main_career_id,
Career.project_id == project_id,
Career.type == 'main'
)
)
career = career_result.scalar_one_or_none()
if career:
character.main_career_id = main_career_id
character.main_career_stage = main_career_stage or 1
# 创建职业关联
char_career = CharacterCareer(
character_id=character.id,
career_id=main_career_id,
career_type='main',
current_stage=main_career_stage or 1,
stage_progress=0
)
db.add(char_career)
else:
warnings.append(f"角色'{name}'的主职业ID不存在,已忽略职业信息")
# 处理副职业(如果有)
sub_careers = char_data.get("sub_careers")
if sub_careers and not is_organization:
try:
sub_careers_data = json.loads(sub_careers) if isinstance(sub_careers, str) else sub_careers
if isinstance(sub_careers_data, list):
valid_sub_careers = []
for sub_data in sub_careers_data[:2]: # 最多2个副职业
if isinstance(sub_data, dict):
career_id = sub_data.get('career_id')
stage = sub_data.get('stage', 1)
if career_id:
# 验证副职业是否存在
career_result = await db.execute(
select(Career).where(
Career.id == career_id,
Career.project_id == project_id,
Career.type == 'sub'
)
)
career = career_result.scalar_one_or_none()
if career:
valid_sub_careers.append({
'career_id': career_id,
'stage': stage
})
# 创建副职业关联
char_career = CharacterCareer(
character_id=character.id,
career_id=career_id,
career_type='sub',
current_stage=stage,
stage_progress=0
)
db.add(char_career)
if valid_sub_careers:
character.sub_careers = json.dumps(valid_sub_careers, ensure_ascii=False)
elif sub_careers_data:
warnings.append(f"角色'{name}'的副职业ID不存在,已忽略副职业信息")
except Exception as e:
warnings.append(f"角色'{name}'的副职业数据解析失败: {str(e)}")
# 如果是组织,创建Organization记录
if is_organization:
organization = Organization(
character_id=character.id,
project_id=project_id,
member_count=0,
power_level=char_data.get("power_level", 50),
location=char_data.get("location"),
motto=char_data.get("motto"),
color=char_data.get("color")
)
db.add(organization)
await db.flush()
imported_organizations.append(name)
else:
imported_characters.append(name)
logger.info(f"导入{'组织' if is_organization else '角色'}成功: {name}")
except Exception as e:
error_msg = f"导入角色'{char_data.get('name', f'{idx+1}')}'失败: {str(e)}"
logger.error(error_msg)
errors.append(error_msg)
continue
# 提交事务
await db.commit()
total = len(imported_characters) + len(imported_organizations)
result = {
"success": True,
"message": f"成功导入 {total} 个角色/组织",
"statistics": {
"total": len(characters_data),
"imported": total,
"skipped": len(skipped),
"errors": len(errors)
},
"details": {
"imported_characters": imported_characters,
"imported_organizations": imported_organizations,
"skipped": skipped,
"errors": errors
},
"warnings": warnings
}
logger.info(f"角色/组织导入完成: 成功{total}个,跳过{len(skipped)}个,失败{len(errors)}")
return result
except Exception as e:
await db.rollback()
logger.error(f"导入角色/组织失败: {str(e)}", exc_info=True)
return {
"success": False,
"message": f"导入失败: {str(e)}",
"statistics": {
"total": len(characters_data) if "data" in data else 0,
"imported": len(imported_characters) + len(imported_organizations),
"skipped": len(skipped),
"errors": len(errors)
},
"details": {
"imported_characters": imported_characters,
"imported_organizations": imported_organizations,
"skipped": skipped,
"errors": errors
},
"warnings": warnings
}
@staticmethod
def validate_characters_import(data: Dict) -> Dict[str, Any]:
"""
验证角色/组织导入数据
Args:
data: 导入的JSON数据
Returns:
Dict: 验证结果
"""
errors = []
warnings = []
# 检查版本
version = data.get("version", "")
if not version:
errors.append("缺少版本信息")
elif version != ImportExportService.SUPPORTED_VERSION:
warnings.append(f"版本不匹配: 导入文件版本为 {version}, 当前支持版本为 {ImportExportService.SUPPORTED_VERSION}")
# 检查导出类型
export_type = data.get("export_type", "")
if export_type != "characters":
errors.append(f"导出类型错误: 期望'characters',实际'{export_type}'")
# 检查数据字段
if "data" not in data:
errors.append("缺少data字段")
elif not isinstance(data["data"], list):
errors.append("data字段必须是数组")
else:
characters_data = data["data"]
# 统计信息
character_count = sum(1 for c in characters_data if not c.get("is_organization", False))
org_count = sum(1 for c in characters_data if c.get("is_organization", False))
# 检查必填字段
for idx, char_data in enumerate(characters_data):
if not char_data.get("name"):
errors.append(f"{idx+1}个角色缺少name字段")
statistics = {
"characters": character_count,
"organizations": org_count
}
if "data" not in data or errors:
statistics = {"characters": 0, "organizations": 0}
return {
"valid": len(errors) == 0,
"version": version,
"statistics": statistics,
"errors": errors,
"warnings": warnings
}