update:1.重构项目数据库初始化和迁移逻辑,使用Alembic数据库管理工具

This commit is contained in:
xiamuceer
2025-12-26 15:05:48 +08:00
parent a5788e75ae
commit f32e51b594
39 changed files with 2249 additions and 2037 deletions
@@ -1,45 +0,0 @@
-- Link chapters to outlines (one outline -> many chapters) and add the
-- expansion_plan column. Every DDL statement is guarded so the script can be
-- re-run without failing on objects that already exist.

-- Nullable reference to the owning outline.
ALTER TABLE chapters
    ADD COLUMN IF NOT EXISTS outline_id VARCHAR(36) NULL;

-- Ordinal of the chapter within its outline.
ALTER TABLE chapters
    ADD COLUMN IF NOT EXISTS sub_index INTEGER DEFAULT 1;

-- Column comments (PostgreSQL syntax).
COMMENT ON COLUMN chapters.outline_id IS '关联的大纲ID';
COMMENT ON COLUMN chapters.sub_index IS '大纲下的子章节序号';

-- Foreign key to outlines. ADD CONSTRAINT has no IF NOT EXISTS form, so the
-- catalog is checked first; the regclass cast resolves "chapters" through the
-- search_path exactly like the ALTER TABLE below does.
DO $$
BEGIN
    IF NOT EXISTS (
        SELECT 1
        FROM pg_constraint
        WHERE conname = 'fk_chapter_outline'
          AND conrelid = 'chapters'::regclass
    ) THEN
        ALTER TABLE chapters
            ADD CONSTRAINT fk_chapter_outline
            FOREIGN KEY (outline_id)
            REFERENCES outlines(id)
            ON DELETE SET NULL;
    END IF;
END $$;

-- Indexes for lookups by outline and by (outline, position).
CREATE INDEX IF NOT EXISTS idx_chapters_outline_id ON chapters(outline_id);
CREATE INDEX IF NOT EXISTS idx_chapters_outline_sub ON chapters(outline_id, sub_index);

-- Notes:
--   outline_id IS NULL     -> legacy data or a standalone chapter
--   outline_id IS NOT NULL -> the chapter was generated by expanding that outline
--   sub_index              -> position of the chapter within the outline (starts at 1)

-- expansion_plan stores the detailed outline-expansion plan as JSON text.
ALTER TABLE chapters ADD COLUMN IF NOT EXISTS expansion_plan TEXT;

COMMENT ON COLUMN chapters.expansion_plan IS '展开规划详情(JSON): 包含key_events, character_focus, emotional_tone等';

-- Show the resulting column layout.
SELECT column_name, data_type, is_nullable, column_default
FROM information_schema.columns
WHERE table_name = 'chapters'
ORDER BY ordinal_position;
-200
View File
@@ -1,200 +0,0 @@
-- Career system migration script (PostgreSQL).
-- Created: 2025-12-20
-- Adds the careers table and the character/career link table.

-- ===== 1. careers table =====
CREATE TABLE IF NOT EXISTS careers (
    id                VARCHAR(36)  PRIMARY KEY,
    project_id        VARCHAR(36)  NOT NULL,

    -- Basic information
    name              VARCHAR(100) NOT NULL,
    type              VARCHAR(20)  NOT NULL,            -- career type: main (primary) / sub (secondary)
    description       TEXT,                             -- career description
    category          VARCHAR(50),                      -- career category (e.g. combat, production, support)

    -- Stage configuration
    stages            TEXT         NOT NULL,            -- stage list (JSON): [{"level":1, "name":"", "description":""}, ...]
    max_stage         INT          NOT NULL DEFAULT 10, -- maximum number of stages

    -- Career traits
    requirements      TEXT,                             -- requirements / restrictions
    special_abilities TEXT,                             -- special abilities
    worldview_rules   TEXT,                             -- related worldview rules

    -- Optional attribute bonuses (JSON)
    attribute_bonuses TEXT,                             -- e.g. {"strength": "+10%", "intelligence": "+5%"}

    -- Metadata
    source            VARCHAR(20)  DEFAULT 'ai',        -- origin: ai / manual
    created_at        TIMESTAMP    DEFAULT CURRENT_TIMESTAMP,
    updated_at        TIMESTAMP    DEFAULT CURRENT_TIMESTAMP,

    -- Deleting a project removes its careers.
    CONSTRAINT fk_career_project
        FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE
);

-- Lookup indexes.
CREATE INDEX IF NOT EXISTS idx_careers_project_id
    ON careers(project_id);
CREATE INDEX IF NOT EXISTS idx_careers_type
    ON careers(type);
-- Keep careers.updated_at current: the trigger function stamps every updated
-- row with CURRENT_TIMESTAMP before it is written.
CREATE OR REPLACE FUNCTION update_careers_updated_at()
RETURNS TRIGGER AS $$
BEGIN
    NEW.updated_at = CURRENT_TIMESTAMP;
    RETURN NEW;
END;
$$ LANGUAGE plpgsql;

-- CREATE TRIGGER has no IF NOT EXISTS form; drop any previous definition first
-- so that re-running this otherwise idempotent script does not fail with
-- "trigger already exists".
DROP TRIGGER IF EXISTS trigger_careers_updated_at ON careers;
CREATE TRIGGER trigger_careers_updated_at
    BEFORE UPDATE ON careers
    FOR EACH ROW
    EXECUTE FUNCTION update_careers_updated_at();

-- Catalog comments for the table and its columns.
COMMENT ON TABLE careers IS '职业表';
COMMENT ON COLUMN careers.name IS '职业名称';
COMMENT ON COLUMN careers.type IS '职业类型: main(主职业)/sub(副职业)';
COMMENT ON COLUMN careers.description IS '职业描述';
COMMENT ON COLUMN careers.category IS '职业分类(如:战斗系、生产系、辅助系)';
COMMENT ON COLUMN careers.stages IS '职业阶段列表(JSON)';
COMMENT ON COLUMN careers.max_stage IS '最大阶段数';
COMMENT ON COLUMN careers.requirements IS '职业要求/限制';
COMMENT ON COLUMN careers.special_abilities IS '特殊能力描述';
COMMENT ON COLUMN careers.worldview_rules IS '世界观规则关联';
COMMENT ON COLUMN careers.attribute_bonuses IS '属性加成(JSON)';
COMMENT ON COLUMN careers.source IS '来源: ai/manual';
COMMENT ON COLUMN careers.created_at IS '创建时间';
COMMENT ON COLUMN careers.updated_at IS '更新时间';
-- ===== 2. character/career link table =====
CREATE TABLE IF NOT EXISTS character_careers (
    id                       VARCHAR(36) PRIMARY KEY,
    character_id             VARCHAR(36) NOT NULL,
    career_id                VARCHAR(36) NOT NULL,
    career_type              VARCHAR(20) NOT NULL,           -- main (primary) / sub (secondary)

    -- Stage progress
    current_stage            INT         NOT NULL DEFAULT 1, -- current stage (matches a stage number of the career)
    stage_progress           INT         DEFAULT 0,          -- progress within the stage (0-100)

    -- In-story timeline markers (free text, not real timestamps)
    started_at               VARCHAR(100),                   -- when training started (story timeline)
    reached_current_stage_at VARCHAR(100),                   -- when the current stage was reached

    -- Notes
    notes                    TEXT,                           -- e.g. training insights, special events
    created_at               TIMESTAMP   DEFAULT CURRENT_TIMESTAMP,
    updated_at               TIMESTAMP   DEFAULT CURRENT_TIMESTAMP,

    -- Links are removed together with the character or the career.
    CONSTRAINT fk_charcareer_character
        FOREIGN KEY (character_id) REFERENCES characters(id) ON DELETE CASCADE,
    CONSTRAINT fk_charcareer_career
        FOREIGN KEY (career_id) REFERENCES careers(id) ON DELETE CASCADE,

    -- A character cannot hold the same career twice.
    CONSTRAINT uk_character_career UNIQUE (character_id, career_id)
);

-- Lookup indexes.
CREATE INDEX IF NOT EXISTS idx_character_careers_character_id
    ON character_careers(character_id);
CREATE INDEX IF NOT EXISTS idx_character_careers_career_type
    ON character_careers(career_type);
-- Keep character_careers.updated_at current: the trigger function stamps every
-- updated row with CURRENT_TIMESTAMP before it is written.
CREATE OR REPLACE FUNCTION update_character_careers_updated_at()
RETURNS TRIGGER AS $$
BEGIN
    NEW.updated_at = CURRENT_TIMESTAMP;
    RETURN NEW;
END;
$$ LANGUAGE plpgsql;

-- CREATE TRIGGER has no IF NOT EXISTS form; drop any previous definition first
-- so that re-running the script does not fail with "trigger already exists".
DROP TRIGGER IF EXISTS trigger_character_careers_updated_at ON character_careers;
CREATE TRIGGER trigger_character_careers_updated_at
    BEFORE UPDATE ON character_careers
    FOR EACH ROW
    EXECUTE FUNCTION update_character_careers_updated_at();

-- Catalog comments for the table and its columns.
COMMENT ON TABLE character_careers IS '角色职业关联表';
COMMENT ON COLUMN character_careers.career_type IS 'main(主职业)/sub(副职业)';
COMMENT ON COLUMN character_careers.current_stage IS '当前阶段(对应职业中的数值)';
-- The closing parenthesis was missing from this description.
COMMENT ON COLUMN character_careers.stage_progress IS '阶段内进度(0-100)';
COMMENT ON COLUMN character_careers.started_at IS '开始修炼时间(小说时间线)';
COMMENT ON COLUMN character_careers.reached_current_stage_at IS '到达当前阶段时间';
COMMENT ON COLUMN character_careers.notes IS '备注(如:修炼心得、特殊事件)';
-- ===== 3. Extend the characters table (optional denormalized columns) =====
-- These columns duplicate data from character_careers to speed up reads.
-- ADD COLUMN IF NOT EXISTS resolves "characters" through the search_path, so a
-- same-named table in another schema cannot make the guard skip the column
-- (the previous information_schema check filtered on table_name only).
ALTER TABLE characters ADD COLUMN IF NOT EXISTS main_career_id VARCHAR(36);
COMMENT ON COLUMN characters.main_career_id IS '主职业ID';

ALTER TABLE characters ADD COLUMN IF NOT EXISTS main_career_stage INT;
COMMENT ON COLUMN characters.main_career_stage IS '主职业当前阶段';

ALTER TABLE characters ADD COLUMN IF NOT EXISTS sub_careers TEXT;
COMMENT ON COLUMN characters.sub_careers IS '副职业列表(JSON): [{"career_id": "xxx", "stage": 3}, ...]';

-- Foreign key from characters.main_career_id to careers. ADD CONSTRAINT has no
-- IF NOT EXISTS form, so the catalog is checked first; the regclass cast pins
-- the lookup to the same table the ALTER TABLE below will modify.
DO $$
BEGIN
    IF NOT EXISTS (
        SELECT 1
        FROM pg_constraint
        WHERE conname = 'fk_main_career'
          AND conrelid = 'characters'::regclass
    ) THEN
        ALTER TABLE characters
            ADD CONSTRAINT fk_main_career
            FOREIGN KEY (main_career_id) REFERENCES careers(id) ON DELETE SET NULL;
    END IF;
END $$;
-- ===== 4. Convenience view (optional) =====
-- One row per character/career link, joined with the character name and the
-- career definition.
CREATE OR REPLACE VIEW v_character_career_details AS
SELECT
    cc.id           AS relation_id,
    cc.character_id,
    c.name          AS character_name,
    cc.career_id,
    ca.name         AS career_name,
    ca.type         AS career_type_name,
    cc.career_type,
    cc.current_stage,
    ca.max_stage,
    cc.stage_progress,
    cc.started_at,
    cc.reached_current_stage_at,
    cc.notes,
    ca.description  AS career_description,
    ca.category     AS career_category,
    ca.stages       AS career_stages_json,
    cc.created_at,
    cc.updated_at
FROM character_careers AS cc
INNER JOIN characters AS c
    ON cc.character_id = c.id
INNER JOIN careers AS ca
    ON cc.career_id = ca.id
ORDER BY
    cc.career_type DESC,
    cc.created_at;

COMMENT ON VIEW v_character_career_details IS '角色职业详细信息视图';
-- ===== Completion notice =====
-- Reports the row counts of both new tables as NOTICE messages.
DO $$
DECLARE
    career_total BIGINT;
    link_total   BIGINT;
BEGIN
    RAISE NOTICE '职业体系数据库表创建完成!';

    SELECT COUNT(*) INTO career_total FROM careers;
    RAISE NOTICE '职业表记录数: %', career_total;

    SELECT COUNT(*) INTO link_total FROM character_careers;
    RAISE NOTICE '角色职业关联表记录数: %', link_total;
END $$;
@@ -1,73 +0,0 @@
-- Chapter regeneration task table.
-- Backs the feature that regenerates chapter content from AI analysis suggestions.
CREATE TABLE IF NOT EXISTS regeneration_tasks (
    id VARCHAR(36) PRIMARY KEY,
    chapter_id VARCHAR(36) NOT NULL,
    -- Plain column on purpose: the former fk_regeneration_analysis constraint is
    -- dropped at the end of this script, so it is no longer created here. This
    -- keeps the final schema identical while not requiring analysis_tasks to
    -- exist when the table is first created.
    analysis_id VARCHAR(36),
    user_id VARCHAR(100) NOT NULL,
    project_id VARCHAR(36) NOT NULL,
    -- Modification instructions
    modification_instructions TEXT NOT NULL,
    original_suggestions JSON,
    selected_suggestion_indices JSON,
    custom_instructions TEXT,
    -- Generation settings
    style_id INTEGER,
    target_word_count INTEGER DEFAULT 3000,
    focus_areas JSON,
    preserve_elements JSON,
    -- Task state
    status VARCHAR(20) DEFAULT 'pending',
    progress INTEGER DEFAULT 0,
    error_message TEXT,
    -- Content
    original_content TEXT,
    original_word_count INTEGER,
    regenerated_content TEXT,
    regenerated_word_count INTEGER,
    -- Versioning
    version_number INTEGER DEFAULT 1,
    version_note TEXT,
    -- Timestamps
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    started_at TIMESTAMP,
    completed_at TIMESTAMP,
    -- Foreign keys
    CONSTRAINT fk_regeneration_chapter FOREIGN KEY (chapter_id) REFERENCES chapters(id) ON DELETE CASCADE,
    CONSTRAINT fk_regeneration_project FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE,
    CONSTRAINT fk_regeneration_style FOREIGN KEY (style_id) REFERENCES writing_styles(id) ON DELETE SET NULL
);

-- Indexes for the common lookup paths.
CREATE INDEX IF NOT EXISTS idx_regeneration_tasks_chapter ON regeneration_tasks(chapter_id);
CREATE INDEX IF NOT EXISTS idx_regeneration_tasks_project ON regeneration_tasks(project_id);
CREATE INDEX IF NOT EXISTS idx_regeneration_tasks_user ON regeneration_tasks(user_id);
CREATE INDEX IF NOT EXISTS idx_regeneration_tasks_status ON regeneration_tasks(status);
CREATE INDEX IF NOT EXISTS idx_regeneration_tasks_created ON regeneration_tasks(created_at DESC);

-- Catalog comments.
COMMENT ON TABLE regeneration_tasks IS '章节重新生成任务表,记录每次根据AI建议重新生成章节的任务';
COMMENT ON COLUMN regeneration_tasks.modification_instructions IS '合并后的完整修改指令';
COMMENT ON COLUMN regeneration_tasks.original_suggestions IS '原始AI分析建议列表';
COMMENT ON COLUMN regeneration_tasks.selected_suggestion_indices IS '用户选择的建议索引';
COMMENT ON COLUMN regeneration_tasks.preserve_elements IS '需要保留的元素配置(JSON)';
COMMENT ON COLUMN regeneration_tasks.focus_areas IS '重点优化方向列表(JSON)';

-- Foreign-key repair (merged from fix_all_missing_columns.sql):
-- databases created by an earlier version of this script may still carry the
-- problematic constraint, so remove it when present.
ALTER TABLE regeneration_tasks
    DROP CONSTRAINT IF EXISTS fk_regeneration_analysis;

-- Completion notice.
SELECT '✅ 重新生成任务表创建完成,外键约束已修复' AS status;
+77
View File
@@ -0,0 +1,77 @@
#!/bin/bash
# Docker container entrypoint.
# Waits for the database, applies Alembic migrations, then starts the app.
set -e  # abort on the first failing command

echo "================================================"
echo "🚀 MuMuAINovel 启动中..."
echo "================================================"

# Database settings (read from the environment, with defaults).
DB_HOST="${DB_HOST:-postgres}"
DB_PORT="${DB_PORT:-5432}"
DB_USER="${POSTGRES_USER:-mumuai}"
DB_NAME="${POSTGRES_DB:-mumuai_novel}"

# Phase 1: wait until the TCP port accepts connections (one probe per second).
echo "⏳ 等待数据库启动..."
MAX_RETRIES=30
RETRY_COUNT=0
while ! nc -z "$DB_HOST" "$DB_PORT" 2>/dev/null; do
    RETRY_COUNT=$((RETRY_COUNT + 1))
    if [ "$RETRY_COUNT" -ge "$MAX_RETRIES" ]; then
        echo "❌ 错误: 数据库连接超时(${MAX_RETRIES}秒)"
        exit 1
    fi
    echo " 等待数据库... ($RETRY_COUNT/$MAX_RETRIES)"
    sleep 1
done
echo "✅ 数据库连接成功"

# Phase 2: an open port does not mean the server accepts queries yet, so retry
# a real query until it succeeds. Previously a single failed probe was followed
# by one sleep and an unconditional "ready" message.
echo "⏳ 等待数据库完全就绪..."
echo "🔍 检查数据库状态..."
if command -v psql > /dev/null 2>&1; then
    READY_MAX_RETRIES=10
    READY_COUNT=0
    # -p is passed explicitly so the probe targets the same port as phase 1.
    until PGPASSWORD="${POSTGRES_PASSWORD}" psql -h "$DB_HOST" -p "$DB_PORT" -U "$DB_USER" -d "$DB_NAME" -c "SELECT 1;" > /dev/null 2>&1; do
        READY_COUNT=$((READY_COUNT + 1))
        if [ "$READY_COUNT" -ge "$READY_MAX_RETRIES" ]; then
            echo "❌ 错误: 数据库在 ${READY_MAX_RETRIES} 次检查后仍未就绪"
            exit 1
        fi
        echo "❌ 数据库尚未就绪,继续等待... ($READY_COUNT/$READY_MAX_RETRIES)"
        sleep 5
    done
    echo "✅ 数据库已就绪"
else
    # Best effort when the psql client is not installed: fall back to a fixed
    # grace period and let the migration step surface any real problem.
    echo "⚠️ 未找到 psql,跳过查询检查"
    sleep 3
fi

# Run database migrations.
echo "================================================"
echo "🔄 执行数据库迁移..."
echo "================================================"
cd /app
# Always use "alembic upgrade head"; Alembic handles both the first deployment
# and incremental upgrades.
echo "🔄 升级数据库到最新版本..."
# The command is tested directly in the "if": under "set -e" a separate
# "$?" check would never run on failure because the shell exits first.
if alembic upgrade head; then
    echo "✅ 数据库迁移成功"
else
    echo "❌ 数据库迁移失败"
    exit 1
fi

echo "================================================"
echo "🎉 启动应用服务..."
echo "================================================"
# Start the application; exec replaces this shell so signals reach uvicorn.
cd /app
exec uvicorn app.main:app \
    --host "${APP_HOST:-0.0.0.0}" \
    --port "${APP_PORT:-8000}" \
    --log-level info \
    --access-log \
    --use-colors
-9
View File
@@ -1,9 +0,0 @@
-- Fix: projects.user_id was too short.
-- Widen it from VARCHAR(36) to VARCHAR(100).
ALTER TABLE projects
    ALTER COLUMN user_id TYPE VARCHAR(100);

-- Verify the change.
SELECT
    column_name,
    data_type,
    character_maximum_length
FROM information_schema.columns
WHERE table_name = 'projects'
  AND column_name = 'user_id';
+184
View File
@@ -0,0 +1,184 @@
#!/usr/bin/env python3
"""
数据库自动迁移脚本
用于开发和生产环境的数据库迁移管理
"""
import subprocess
import sys
import os
from pathlib import Path
# 添加项目路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from app.logger import get_logger
logger = get_logger(__name__)
def run_command(cmd: list, description: str) -> bool:
    """Run an external command from the project root and report success.

    Args:
        cmd: Argument vector handed to ``subprocess.run`` (no shell is used).
        description: Human-readable label used in the log messages.

    Returns:
        True when the process exits with status 0; False when it exits with a
        non-zero status or could not be started at all.
    """
    try:
        logger.info(f"🚀 {description}...")
        result = subprocess.run(
            cmd,
            cwd=project_root,
            capture_output=True,
            text=True,
            check=False
        )
        succeeded = result.returncode == 0
        if succeeded:
            logger.info(f"{description}成功")
        else:
            logger.error(f"{description}失败")
        # Relay BOTH captured streams regardless of the outcome. Previously
        # stderr was dropped on success and stdout on failure, which hid
        # output from tools that log progress to stderr or print error
        # details to stdout.
        if result.stdout:
            print(result.stdout)
        if result.stderr:
            print(result.stderr, file=sys.stderr)
        return succeeded
    except Exception as e:
        # Typically the executable is missing or not runnable.
        logger.error(f"{description}异常: {e}")
        return False
def create_migration(message: str = None):
    """Generate a new Alembic revision with ``--autogenerate``.

    Args:
        message: Revision description. When empty or omitted the user is
            prompted; if nothing is entered, or no interactive stdin is
            available, ``"auto_migration"`` is used.

    Returns:
        True when ``alembic revision`` succeeded, otherwise False.
    """
    if not message:
        try:
            message = input("请输入迁移描述: ").strip()
        except EOFError:
            # stdin is closed or not interactive (CI, containers): fall back to
            # the default description instead of crashing with a traceback.
            message = ""
    if not message:
        message = "auto_migration"
    cmd = ["alembic", "revision", "--autogenerate", "-m", message]
    return run_command(cmd, f"生成迁移: {message}")
def upgrade_database(revision: str = "head"):
    """Run ``alembic upgrade <revision>`` and return whether it succeeded."""
    return run_command(
        ["alembic", "upgrade", revision],
        f"升级数据库到: {revision}",
    )
def downgrade_database(revision: str = "-1"):
    """Run ``alembic downgrade <revision>`` and return whether it succeeded."""
    return run_command(
        ["alembic", "downgrade", revision],
        f"降级数据库到: {revision}",
    )
def show_current():
    """Run ``alembic current`` and return whether it succeeded."""
    return run_command(["alembic", "current"], "查看当前版本")
def show_history():
    """Run ``alembic history --verbose`` and return whether it succeeded."""
    return run_command(["alembic", "history", "--verbose"], "查看迁移历史")
def show_heads():
    """Run ``alembic heads`` and return whether it succeeded."""
    return run_command(["alembic", "heads"], "查看最新版本")
def stamp_database(revision: str = "head"):
    """Run ``alembic stamp <revision>`` (marks the version without migrating)."""
    return run_command(
        ["alembic", "stamp", revision],
        f"标记数据库版本: {revision}",
    )
def auto_migrate():
    """Generate a migration named "auto_migration" and apply it.

    Returns:
        True when both steps succeed; False as soon as one of them fails
        (the second step is not attempted after a failed first step).
    """
    banner = "=" * 60
    for line in (banner, "🔄 开始自动迁移流程", banner):
        logger.info(line)

    # Each step is paired with the message logged when it fails.
    steps = (
        (lambda: create_migration("auto_migration"), "❌ 自动迁移失败:无法生成迁移"),
        (lambda: upgrade_database(), "❌ 自动迁移失败:无法执行迁移"),
    )
    for run_step, failure_message in steps:
        if not run_step():
            logger.error(failure_message)
            return False

    for line in (banner, "✅ 自动迁移完成", banner):
        logger.info(line)
    return True
def init_database():
    """Initialise the database for a first deployment.

    Tries to generate a revision named "initial_migration" (a failure here is
    only logged as a warning) and then upgrades the database.

    Returns:
        True when the upgrade succeeded, otherwise False.
    """
    banner = "=" * 60
    for line in (banner, "🔧 初始化数据库", banner):
        logger.info(line)

    # A failed generation is tolerated: the upgrade below still runs.
    created = create_migration("initial_migration")
    if not created:
        logger.warning("⚠️ 无法创建初始迁移,可能已存在")

    if upgrade_database():
        for line in (banner, "✅ 数据库初始化完成", banner):
            logger.info(line)
        return True

    logger.error("❌ 初始化失败")
    return False
def main():
    """Command-line entry point: dispatch ``sys.argv[1]`` to a migration action.

    Prints the usage text and exits with status 1 when no command is given.
    Otherwise exits with status 0 when the selected action reports success and
    1 when it fails or the command is unknown.
    """
    args = sys.argv
    if len(args) < 2:
        usage_lines = (
            "使用方法:",
            " python migrate.py create [message] - 创建新迁移",
            " python migrate.py upgrade [revision] - 升级数据库(默认: head)",
            " python migrate.py downgrade [revision] - 降级数据库(默认: -1)",
            " python migrate.py current - 显示当前版本",
            " python migrate.py history - 显示迁移历史",
            " python migrate.py heads - 显示最新版本",
            " python migrate.py stamp [revision] - 标记版本(默认: head)",
            " python migrate.py auto - 自动迁移(生成+执行)",
            " python migrate.py init - 初始化数据库",
        )
        for usage_line in usage_lines:
            print(usage_line)
        sys.exit(1)

    command = args[1]
    # Optional second argument; None means "not supplied".
    extra = args[2] if len(args) > 2 else None

    def _or_default(default):
        """Return the optional argument, or `default` when it was not given."""
        return default if extra is None else extra

    handlers = {
        "create": lambda: create_migration(extra),
        "upgrade": lambda: upgrade_database(_or_default("head")),
        "downgrade": lambda: downgrade_database(_or_default("-1")),
        "current": lambda: show_current(),
        "history": lambda: show_history(),
        "heads": lambda: show_heads(),
        "stamp": lambda: stamp_database(_or_default("head")),
        "auto": lambda: auto_migrate(),
        "init": lambda: init_database(),
    }

    handler = handlers.get(command)
    if handler is None:
        logger.error(f"❌ 未知命令: {command}")
        success = False
    else:
        success = handler()

    sys.exit(0 if success else 1)
if __name__ == "__main__":
main()
@@ -1,859 +0,0 @@
#!/usr/bin/env python3
"""
SQLite to PostgreSQL 数据迁移脚本
使用方法:
python backend/scripts/migrate_sqlite_to_postgres.py
前置条件:
1. PostgreSQL数据库已创建
2. .env文件中DATABASE_URL已配置为PostgreSQL
3. SQLite数据文件存在于 backend/data/ 目录
"""
import asyncio
import sys
from pathlib import Path
from typing import List, Dict, Any
import logging
from datetime import datetime
# 添加项目根目录到Python路径
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy import create_engine, text, select
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from app.database import Base
from app.models import (
Project, Outline, Character, Chapter, GenerationHistory,
Settings, WritingStyle, ProjectDefaultStyle,
RelationshipType, CharacterRelationship, Organization, OrganizationMember,
StoryMemory, PlotAnalysis, AnalysisTask, BatchGenerationTask,
MCPPlugin
)
from app.config import settings
# Create the log directory next to the scripts directory's parent.
# NOTE: mkdir is called without parents=True, so the parent directory must already exist.
log_dir = Path(__file__).parent.parent / "logs"
log_dir.mkdir(exist_ok=True)
# Build the log file name (timestamped, so every run gets its own file)
log_filename = log_dir / f"migration_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
# Configure logging to write to both the console and the log file.
# This runs at import time, i.e. importing this module has these side effects.
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(), # console output
logging.FileHandler(log_filename, encoding='utf-8') # file output
]
)
logger = logging.getLogger(__name__)
logger.info(f"📝 日志文件: {log_filename}")
class SQLiteToPostgresMigrator:
"""SQLite到PostgreSQL的数据迁移器"""
def __init__(self, sqlite_dir: Path, target_user_id: str):
    """Set up the migrator and its PostgreSQL engine/session factory.

    Args:
        sqlite_dir: Directory that holds the SQLite database files.
        target_user_id: User id that migrated data is attributed to when it
            cannot be derived from a database file name.

    Raises:
        ValueError: If ``settings.database_url`` does not contain
            ``"postgresql"``.
    """
    self.sqlite_dir = sqlite_dir
    self.target_user_id = target_user_id
    # Only files following the per-user naming scheme are picked up.
    self.sqlite_files = list(sqlite_dir.glob("ai_story_user_*.db"))

    # Target (PostgreSQL) connection
    database_url = settings.database_url
    if "postgresql" not in database_url:
        raise ValueError("DATABASE_URL必须配置为PostgreSQL")

    engine = create_async_engine(database_url, echo=False, pool_pre_ping=True)
    self.pg_engine = engine
    self.pg_session_maker = async_sessionmaker(
        engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )
async def migrate_all(self):
    """Migrate every discovered SQLite database into PostgreSQL.

    Logs a warning and returns without touching PostgreSQL when no SQLite
    file was found. Otherwise creates the tables, seeds the relationship
    types, migrates the files one by one and finally resets the sequences.
    """
    databases = self.sqlite_files
    if not databases:
        logger.warning(f"未找到SQLite数据库文件: {self.sqlite_dir}")
        return

    logger.info(f"找到 {len(databases)} 个SQLite数据库文件")

    # Prepare the target: schema first, then the preset lookup data.
    await self._create_tables()
    await self._init_relationship_types()

    # Migrate the files sequentially.
    for database in databases:
        await self._migrate_single_db(database)

    # Reset auto-increment sequences after the bulk insert.
    await self._reset_sequences()
    logger.info("✅ 所有数据迁移完成")
async def _create_tables(self):
    """Create all tables registered on ``Base.metadata`` in PostgreSQL."""
    logger.info("创建PostgreSQL表结构...")
    # create_all is a synchronous API, so it is run through run_sync on the
    # async connection; the surrounding begin() block wraps it in a transaction.
    async with self.pg_engine.begin() as pg_conn:
        await pg_conn.run_sync(Base.metadata.create_all)
    logger.info("✅ 表结构创建完成")
async def _init_relationship_types(self):
"""Seed the relationship-type lookup table with the preset entries.

Does nothing when the table already contains at least one row. Any
exception is logged and swallowed so that the migration can continue.
"""
logger.info("初始化关系类型数据...")
# Preset relationship types
relationship_types = [
# Family relationships
{"name": "父亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👨"},
{"name": "母亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👩"},
{"name": "兄弟", "category": "family", "reverse_name": "兄弟", "intimacy_range": "high", "icon": "👬"},
{"name": "姐妹", "category": "family", "reverse_name": "姐妹", "intimacy_range": "high", "icon": "👭"},
{"name": "子女", "category": "family", "reverse_name": "父母", "intimacy_range": "high", "icon": "👶"},
{"name": "配偶", "category": "family", "reverse_name": "配偶", "intimacy_range": "high", "icon": "💑"},
{"name": "恋人", "category": "family", "reverse_name": "恋人", "intimacy_range": "high", "icon": "💕"},
# Social relationships
{"name": "师父", "category": "social", "reverse_name": "徒弟", "intimacy_range": "high", "icon": "🎓"},
{"name": "徒弟", "category": "social", "reverse_name": "师父", "intimacy_range": "high", "icon": "📚"},
{"name": "朋友", "category": "social", "reverse_name": "朋友", "intimacy_range": "medium", "icon": "🤝"},
{"name": "同学", "category": "social", "reverse_name": "同学", "intimacy_range": "medium", "icon": "🎒"},
{"name": "邻居", "category": "social", "reverse_name": "邻居", "intimacy_range": "low", "icon": "🏘️"},
{"name": "知己", "category": "social", "reverse_name": "知己", "intimacy_range": "high", "icon": "💙"},
# Professional relationships
{"name": "上司", "category": "professional", "reverse_name": "下属", "intimacy_range": "low", "icon": "👔"},
{"name": "下属", "category": "professional", "reverse_name": "上司", "intimacy_range": "low", "icon": "💼"},
{"name": "同事", "category": "professional", "reverse_name": "同事", "intimacy_range": "medium", "icon": "🤵"},
{"name": "合作伙伴", "category": "professional", "reverse_name": "合作伙伴", "intimacy_range": "medium", "icon": "🤜🤛"},
# Hostile relationships
{"name": "敌人", "category": "hostile", "reverse_name": "敌人", "intimacy_range": "low", "icon": "⚔️"},
{"name": "仇人", "category": "hostile", "reverse_name": "仇人", "intimacy_range": "low", "icon": "💢"},
{"name": "竞争对手", "category": "hostile", "reverse_name": "竞争对手", "intimacy_range": "low", "icon": "🎯"},
# NOTE(review): the icon below is an empty string, unlike every other entry — confirm whether an emoji was lost.
{"name": "宿敌", "category": "hostile", "reverse_name": "宿敌", "intimacy_range": "low", "icon": ""},
]
try:
async with self.pg_session_maker() as session:
# Check whether any row already exists
result = await session.execute(select(RelationshipType))
existing = result.scalars().first()
if existing:
logger.info("关系类型数据已存在,跳过初始化")
return
# Insert the preset rows
logger.info("开始插入关系类型数据...")
for rt_data in relationship_types:
relationship_type = RelationshipType(**rt_data)
session.add(relationship_type)
await session.commit()
logger.info(f"✅ 成功插入 {len(relationship_types)} 条关系类型数据")
except Exception as e:
logger.error(f"初始化关系类型数据失败: {str(e)}", exc_info=True)
# Deliberately not re-raised: the migration continues without the preset data
logger.warning("关系类型初始化失败,将跳过有外键依赖的记录")
async def _migrate_single_db(self, sqlite_file: Path):
    """Migrate one SQLite database file into PostgreSQL.

    The user id is taken from the file name (``ai_story_user_<id>``) and
    falls back to ``self.target_user_id``. All tables are copied inside a
    single PostgreSQL session that is committed at the end. Any exception is
    logged and swallowed, and the SQLite engine is always disposed.
    """
    # Derive the user id from the file name (ai_story_user_xxx).
    prefix = "ai_story_user_"
    stem = sqlite_file.stem
    user_id = stem.replace(prefix, "") if stem.startswith(prefix) else self.target_user_id

    separator = "=" * 60
    logger.info(f"\n{separator}")
    logger.info(f"开始迁移: {sqlite_file.name} -> user_id: {user_id}")
    logger.info(f"{separator}")

    # Source (SQLite) connection
    source_engine = create_async_engine(
        f"sqlite+aiosqlite:///{sqlite_file.absolute()}",
        echo=False,
    )
    source_session_maker = async_sessionmaker(
        source_engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )

    # Tables in dependency order: referenced tables come before referencing ones.
    migration_plan = (
        (Settings, "设置"),
        (Project, "项目"),
        (Character, "角色"),
        (Outline, "大纲"),
        (Chapter, "章节"),
        (CharacterRelationship, "角色关系"),
        (Organization, "组织"),
        (OrganizationMember, "组织成员"),
        (GenerationHistory, "生成历史"),
        (WritingStyle, "写作风格"),
        (ProjectDefaultStyle, "项目默认风格"),
        (StoryMemory, "记忆"),
        (PlotAnalysis, "剧情分析"),
        (AnalysisTask, "分析任务"),
        (BatchGenerationTask, "批量生成任务"),
        (MCPPlugin, "MCP插件"),
    )

    try:
        async with source_session_maker() as sqlite_session, self.pg_session_maker() as pg_session:
            for model_class, label in migration_plan:
                await self._migrate_table(
                    sqlite_session, pg_session, user_id, model_class, label
                )
            await pg_session.commit()
            logger.info(f"{sqlite_file.name} 迁移完成")
    except Exception as e:
        logger.error(f"❌ 迁移失败: {e}", exc_info=True)
    finally:
        await source_engine.dispose()
async def _migrate_table(
self,
sqlite_session: AsyncSession,
pg_session: AsyncSession,
user_id: str,
model_class,
table_name: str
):
"""Copy the rows of one table from SQLite into the PostgreSQL session.

Only columns that exist in both the model and the SQLite table are read.
Rows are skipped when a field fails length validation, a foreign key
fails validation, or a row with the same id already exists in PostgreSQL.
Accepted rows are added to the session and flushed (not committed).

Args:
sqlite_session: Session on the source SQLite database.
pg_session: Session on the target PostgreSQL database.
user_id: Value stored in ``user_id`` when the model has that attribute
but the SQLite row does not provide it.
model_class: ORM model class of the table to migrate.
table_name: Human-readable table label used in log messages.

Raises:
Exception: Any error outside the per-record construction is logged and re-raised.
"""
try:
# Determine which columns actually exist in the SQLite table
sqlite_table = model_class.__table__
sqlite_conn = await sqlite_session.connection()
# Query the SQLite table structure
inspect_result = await sqlite_conn.execute(
text(f"PRAGMA table_info({sqlite_table.name})")
)
sqlite_columns = {row[1] for row in inspect_result.fetchall()} # row[1] is the column name
# Build the select list from the model columns that SQLite also has
available_columns = [
c for c in model_class.__table__.columns
if c.name in sqlite_columns
]
if not available_columns:
logger.warning(f" ⚠️ {table_name}: 表结构不匹配,跳过")
return
# Read the data from SQLite (existing columns only)
result = await sqlite_session.execute(
select(*available_columns)
)
records = result.all()
if not records:
logger.info(f" - {table_name}: 无数据")
return
# Build a dict per record and add user_id where needed
migrated_count = 0
skipped_count = 0
for record in records:
# Map the positional result row back to column names
record_dict = {}
for i, col in enumerate(available_columns):
record_dict[col.name] = record[i]
# Add user_id (the PostgreSQL model has it but the SQLite row does not)
if hasattr(model_class, 'user_id') and 'user_id' not in record_dict:
record_dict['user_id'] = user_id
# Validate field lengths (over-long values would make the insert fail)
if not self._validate_field_lengths(model_class, record_dict, table_name):
skipped_count += 1
record_id = record_dict.get('id', 'unknown')
logger.warning(f" ⚠️ [{table_name}] 跳过超长字段记录 ID={record_id}")
continue
# Validate foreign-key references (for tables that have them)
validation_result = await self._validate_foreign_keys(pg_session, model_class, record_dict)
if not validation_result:
skipped_count += 1
record_id = record_dict.get('id', 'unknown')
logger.warning(f" ⚠️ [{table_name}] 跳过无效外键记录 ID={record_id}")
# Log table-specific record details to help debugging
if model_class.__tablename__ == 'story_memories':
logger.warning(f" 记忆详情: project_id={record_dict.get('project_id')}, "
f"chapter_id={record_dict.get('chapter_id')}, "
f"type={record_dict.get('memory_type')}")
elif model_class.__tablename__ == 'character_relationships':
logger.warning(f" 关系详情: project_id={record_dict.get('project_id')}, "
f"from={record_dict.get('character_from_id')}, "
f"to={record_dict.get('character_to_id')}, "
f"type_id={record_dict.get('relationship_type_id')}")
elif model_class.__tablename__ == 'organizations':
logger.warning(f" 组织详情: project_id={record_dict.get('project_id')}, "
f"character_id={record_dict.get('character_id')}")
elif model_class.__tablename__ == 'organization_members':
logger.warning(f" 成员详情: org_id={record_dict.get('organization_id')}, "
f"character_id={record_dict.get('character_id')}")
elif model_class.__tablename__ == 'writing_styles':
logger.warning(f" 写作风格详情: project_id={record_dict.get('project_id')}, "
f"name={record_dict.get('name')}, "
f"style_type={record_dict.get('style_type')}")
elif model_class.__tablename__ == 'characters':
logger.warning(f" 角色详情: project_id={record_dict.get('project_id')}, "
f"name={record_dict.get('name')}, "
f"is_organization={record_dict.get('is_organization')}")
elif model_class.__tablename__ == 'outlines':
logger.warning(f" 大纲详情: project_id={record_dict.get('project_id')}, "
f"title={record_dict.get('title')}")
elif model_class.__tablename__ == 'chapters':
logger.warning(f" 章节详情: project_id={record_dict.get('project_id')}, "
f"title={record_dict.get('title')}, "
f"chapter_number={record_dict.get('chapter_number')}")
elif model_class.__tablename__ == 'generation_history':
logger.warning(f" 生成历史详情: project_id={record_dict.get('project_id')}, "
f"chapter_id={record_dict.get('chapter_id')}, "
f"model={record_dict.get('model')}")
elif model_class.__tablename__ == 'plot_analysis':
logger.warning(f" 剧情分析详情: project_id={record_dict.get('project_id')}, "
f"chapter_id={record_dict.get('chapter_id')}, "
f"plot_stage={record_dict.get('plot_stage')}")
elif model_class.__tablename__ == 'analysis_tasks':
logger.warning(f" 分析任务详情: chapter_id={record_dict.get('chapter_id')}, "
f"project_id={record_dict.get('project_id')}, "
f"status={record_dict.get('status')}")
elif model_class.__tablename__ == 'batch_generation_tasks':
logger.warning(f" 批量生成任务详情: project_id={record_dict.get('project_id')}, "
f"status={record_dict.get('status')}, "
f"completed={record_dict.get('completed_chapters')}/{record_dict.get('total_chapters')}")
elif model_class.__tablename__ == 'project_default_styles':
logger.warning(f" 项目默认风格详情: project_id={record_dict.get('project_id')}, "
f"style_id={record_dict.get('style_id')}")
continue
# Skip records that already exist (avoids primary-key conflicts).
# A falsy id (None, 0, "") bypasses this check.
record_id = record_dict.get('id')
if record_id and await self._record_exists(pg_session, model_class, record_id):
skipped_count += 1
logger.debug(f" 跳过已存在的记录: {record_id}")
continue
# Create the new record
try:
new_record = model_class(**record_dict)
pg_session.add(new_record)
migrated_count += 1
except Exception as e:
logger.warning(f" ⚠️ 跳过无效记录: {str(e)[:100]}")
skipped_count += 1
continue
# Send the pending inserts to PostgreSQL; the caller is responsible for the commit
await pg_session.flush()
# Note: skipped_count also includes records skipped because they already existed
if skipped_count > 0:
logger.info(f"{table_name}: {migrated_count} 条记录(跳过 {skipped_count} 条无效记录)")
else:
logger.info(f"{table_name}: {migrated_count} 条记录")
except Exception as e:
logger.error(f"{table_name} 迁移失败: {e}")
raise
async def _record_exists(
    self,
    pg_session: AsyncSession,
    model_class,
    record_id: Any
) -> bool:
    """Check whether a row with the given primary-key value already exists.

    Only the first primary-key column of the model's table is compared.

    Args:
        pg_session: PostgreSQL session used for the lookup.
        model_class: ORM model class whose table is queried.
        record_id: Primary-key value to look for.

    Returns:
        bool: True when a matching row was found. False when none was found
        or when the lookup itself failed (the failure is logged).
    """
    try:
        # First primary-key column of the table
        pk_column = list(model_class.__table__.primary_key.columns)[0]
        result = await pg_session.execute(
            select(pk_column).where(pk_column == record_id)
        )
        return result.scalar_one_or_none() is not None
    except Exception as e:
        # Still best effort (report "not found"), but no longer silent: a
        # failed lookup usually explains the insert error that follows.
        logger.warning(f" ⚠️ 检查记录是否存在失败 (ID={record_id}): {e}")
        return False
async def _validate_foreign_keys(
self,
pg_session: AsyncSession,
model_class,
record_dict: Dict[str, Any]
) -> bool:
"""
验证记录的外键是否有效
Args:
pg_session: PostgreSQL会话
model_class: 模型类
record_dict: 记录字典
Returns:
bool: 外键是否全部有效
"""
from app.models import Character, Project, Chapter
# 使用no_autoflush防止过早flush
with pg_session.no_autoflush:
# 针对StoryMemory表验证外键
if model_class.__tablename__ == 'story_memories':
# 验证project_id
project_id = record_dict.get('project_id')
if project_id:
result = await pg_session.execute(
select(Project.id).where(Project.id == project_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [记忆] 无效的project_id: {project_id}")
return False
# 验证chapter_id(可选)
chapter_id = record_dict.get('chapter_id')
if chapter_id:
result = await pg_session.execute(
select(Chapter.id).where(Chapter.id == chapter_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [记忆] 无效的chapter_id: {chapter_id}")
return False
# 针对CharacterRelationship表验证外键
elif model_class.__tablename__ == 'character_relationships':
# 验证project_id
project_id = record_dict.get('project_id')
if project_id:
result = await pg_session.execute(
select(Project.id).where(Project.id == project_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ 无效的project_id: {project_id}")
return False
# 验证character_from_id
char_from_id = record_dict.get('character_from_id')
if char_from_id:
result = await pg_session.execute(
select(Character.id).where(Character.id == char_from_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ 无效的character_from_id: {char_from_id}")
return False
# 验证character_to_id
char_to_id = record_dict.get('character_to_id')
if char_to_id:
result = await pg_session.execute(
select(Character.id).where(Character.id == char_to_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ 无效的character_to_id: {char_to_id}")
return False
# 验证relationship_type_id
rel_type_id = record_dict.get('relationship_type_id')
if rel_type_id:
result = await pg_session.execute(
select(RelationshipType.id).where(RelationshipType.id == rel_type_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ 无效的relationship_type_id: {rel_type_id}")
return False
# 针对Organization表验证外键
elif model_class.__tablename__ == 'organizations':
# 验证character_id
char_id = record_dict.get('character_id')
if char_id:
result = await pg_session.execute(
select(Character.id).where(Character.id == char_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [组织] 无效的character_id: {char_id}")
return False
# 针对OrganizationMember表验证外键
elif model_class.__tablename__ == 'organization_members':
from app.models import Organization
# 验证organization_id
org_id = record_dict.get('organization_id')
if org_id:
result = await pg_session.execute(
select(Organization.id).where(Organization.id == org_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ 无效的organization_id: {org_id}")
return False
# 验证character_id
char_id = record_dict.get('character_id')
if char_id:
result = await pg_session.execute(
select(Character.id).where(Character.id == char_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [组织成员] 无效的character_id: {char_id}")
return False
# 针对Character表验证外键
elif model_class.__tablename__ == 'characters':
# 验证project_id
project_id = record_dict.get('project_id')
if project_id:
result = await pg_session.execute(
select(Project.id).where(Project.id == project_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [角色] 无效的project_id: {project_id}")
return False
# 针对Outline表验证外键
elif model_class.__tablename__ == 'outlines':
# 验证project_id
project_id = record_dict.get('project_id')
if project_id:
result = await pg_session.execute(
select(Project.id).where(Project.id == project_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [大纲] 无效的project_id: {project_id}")
return False
# 针对Chapter表验证外键
elif model_class.__tablename__ == 'chapters':
# 验证project_id
project_id = record_dict.get('project_id')
if project_id:
result = await pg_session.execute(
select(Project.id).where(Project.id == project_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [章节] 无效的project_id: {project_id}")
return False
# 针对WritingStyle表验证外键
elif model_class.__tablename__ == 'writing_styles':
# 验证project_id(可选)
project_id = record_dict.get('project_id')
if project_id:
result = await pg_session.execute(
select(Project.id).where(Project.id == project_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [写作风格] 无效的project_id: {project_id}")
return False
# 针对GenerationHistory表验证外键
elif model_class.__tablename__ == 'generation_history':
# 验证project_id
project_id = record_dict.get('project_id')
if project_id:
result = await pg_session.execute(
select(Project.id).where(Project.id == project_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [生成历史] 无效的project_id: {project_id}")
return False
# 验证chapter_id(可选)
chapter_id = record_dict.get('chapter_id')
if chapter_id:
result = await pg_session.execute(
select(Chapter.id).where(Chapter.id == chapter_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [生成历史] 无效的chapter_id: {chapter_id}")
return False
# 针对PlotAnalysis表验证外键
elif model_class.__tablename__ == 'plot_analysis':
# 验证project_id(必需)
project_id = record_dict.get('project_id')
if project_id:
result = await pg_session.execute(
select(Project.id).where(Project.id == project_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [剧情分析] 无效的project_id: {project_id}")
return False
# 验证chapter_id(必需)
chapter_id = record_dict.get('chapter_id')
if chapter_id:
result = await pg_session.execute(
select(Chapter.id).where(Chapter.id == chapter_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [剧情分析] 无效的chapter_id: {chapter_id}")
return False
# 针对AnalysisTask表验证外键
elif model_class.__tablename__ == 'analysis_tasks':
# 验证chapter_id(必需)
chapter_id = record_dict.get('chapter_id')
if chapter_id:
result = await pg_session.execute(
select(Chapter.id).where(Chapter.id == chapter_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [分析任务] 无效的chapter_id: {chapter_id}")
return False
# 验证project_id
project_id = record_dict.get('project_id')
if project_id:
result = await pg_session.execute(
select(Project.id).where(Project.id == project_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [分析任务] 无效的project_id: {project_id}")
return False
# 针对BatchGenerationTask表验证外键
elif model_class.__tablename__ == 'batch_generation_tasks':
# 验证project_id(必需)
project_id = record_dict.get('project_id')
if project_id:
result = await pg_session.execute(
select(Project.id).where(Project.id == project_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [批量生成任务] 无效的project_id: {project_id}")
return False
# 针对ProjectDefaultStyle表验证外键
elif model_class.__tablename__ == 'project_default_styles':
from app.models import WritingStyle
# 验证project_id(必需)
project_id = record_dict.get('project_id')
if project_id:
result = await pg_session.execute(
select(Project.id).where(Project.id == project_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [项目默认风格] 无效的project_id: {project_id}")
return False
# 验证style_id(必需)
style_id = record_dict.get('style_id')
if style_id:
result = await pg_session.execute(
select(WritingStyle.id).where(WritingStyle.id == style_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [项目默认风格] 无效的style_id: {style_id}")
return False
return True
def _validate_field_lengths(
    self,
    model_class,
    record_dict: Dict[str, Any],
    table_name: str
) -> bool:
    """
    Check that every length-limited string value of a record fits its column.

    Only columns whose type is a SQLAlchemy ``String`` with a declared length
    are inspected. The record is never modified: the first oversized value is
    logged and the whole record is reported as invalid.

    Args:
        model_class: Model class whose ``__table__.columns`` supply the limits.
        record_dict: Record to check, keyed by column name.
        table_name: Table name, used only as a prefix in log messages.

    Returns:
        bool: True when no string value exceeds its column length, False as
        soon as the first oversized value is found.
    """
    from sqlalchemy import String

    for column in model_class.__table__.columns:
        # Columns without a declared String length have no limit to enforce.
        if not (isinstance(column.type, String) and column.type.length):
            continue

        field_name = column.name
        field_value = record_dict.get(field_name)
        max_length = column.type.length

        # Missing, empty and non-str values are accepted as-is.
        if not (field_value and isinstance(field_value, str) and len(field_value) > max_length):
            continue

        logger.warning(
            f" ❌ [{table_name}] 字段 '{field_name}' 超长: "
            f"{len(field_value)} > {max_length} (截断了 {len(field_value) - max_length} 字符)"
        )
        if field_name == 'api_key':
            # A credential must not end up (almost) verbatim in the log file:
            # show only a few leading/trailing characters for identification.
            preview = field_value[:4] + "..." + field_value[-4:] if len(field_value) > 8 else "****"
            logger.warning(f" 值预览: {preview}")
        elif field_name == 'api_base_url':
            preview = field_value[:50] + "..." + field_value[-20:] if len(field_value) > 70 else field_value
            logger.warning(f" 值预览: {preview}")
        return False
    return True
async def _reset_sequences(self):
    """
    Re-align PostgreSQL auto-increment sequences with the migrated rows.

    For each table in a fixed list, reads ``MAX(id)`` and calls ``setval`` on
    the conventionally named ``<table>_<column>_seq`` sequence so the next
    generated id is ``MAX(id) + 1``. Empty tables are skipped.

    Each table is handled in its own transaction. On PostgreSQL a failed
    statement aborts the current transaction, so without a rollback one
    failing table would make every following table fail as well.
    Failures are logged as warnings and never raised.
    """
    logger.info("\n" + "="*60)
    logger.info("重置自增序列...")
    logger.info("="*60)

    # Tables with an Integer auto-increment primary key. These identifiers are
    # hard-coded here, so interpolating them into the SQL text is safe.
    tables_with_sequences = [
        ('relationship_types', 'id'),
        ('writing_styles', 'id'),
        ('project_default_styles', 'id'),
    ]

    async with self.pg_session_maker() as session:
        for table_name, id_column in tables_with_sequences:
            try:
                result = await session.execute(
                    text(f"SELECT MAX({id_column}) FROM {table_name}")
                )
                max_id = result.scalar()
                if max_id is not None:
                    # setval(seq, max_id, true): the next nextval() returns max_id + 1.
                    sequence_name = f"{table_name}_{id_column}_seq"
                    await session.execute(
                        text(f"SELECT setval('{sequence_name}', :max_id, true)"),
                        {"max_id": max_id}
                    )
                    logger.info(f"{table_name}: 序列重置到 {max_id}")
                else:
                    logger.info(f" - {table_name}: 表为空,跳过序列重置")
                # Close this table's transaction before moving to the next one.
                await session.commit()
            except Exception as e:
                # Leave the aborted transaction so the next table starts clean.
                await session.rollback()
                logger.warning(f" ⚠️ {table_name}: 序列重置失败 - {str(e)}")
    logger.info("✅ 序列重置完成")
async def cleanup(self):
    """Dispose of the PostgreSQL engine held in ``self.pg_engine``."""
    engine = self.pg_engine
    await engine.dispose()
async def main():
    """
    Interactive entry point of the SQLite -> PostgreSQL migration tool.

    Prints a banner and the effective configuration, asks for confirmation,
    runs ``SQLiteToPostgresMigrator.migrate_all`` and always calls
    ``cleanup`` on the migrator afterwards. Failures are printed and logged
    with a traceback; they are not re-raised.

    The database URL is shown with its password masked, because the same text
    is also written to the log file.
    """
    import re

    banner = """
╔══════════════════════════════════════════════════════════════╗
║ SQLite to PostgreSQL 数据迁移工具 ║
║ ║
║ 此工具将SQLite数据迁移到PostgreSQL ║
║ 请确保: ║
║ 1. PostgreSQL数据库已创建 ║
║ 2. .env中DATABASE_URL已配置为PostgreSQL ║
║ 3. SQLite数据文件存在 ║
╚══════════════════════════════════════════════════════════════╝
"""
    print(banner)
    logger.info(banner)

    # Configuration
    sqlite_dir = Path(__file__).parent.parent / "data"
    target_user_id = "migrated_user"  # default owner for migrated records

    # Replace "user:password@" with "user:****@". The greedy match runs to the
    # last '@' so a password that itself contains '@' is still fully hidden.
    masked_db_url = re.sub(r"(://[^:/@]*):.*@", r"\1:****@", str(settings.database_url))

    config_info = f"""
配置信息:
SQLite目录: {sqlite_dir}
PostgreSQL: {masked_db_url}
目标用户ID: {target_user_id}
日志文件: {log_filename}
"""
    print(config_info)
    logger.info(config_info)

    # Confirmation
    response = input("是否继续迁移? (yes/no): ")
    if response.lower() not in ['yes', 'y']:
        print("已取消迁移")
        return

    # Run the migration
    migrator = SQLiteToPostgresMigrator(sqlite_dir, target_user_id)
    try:
        await migrator.migrate_all()
        success_msg = """
🎉 数据迁移成功完成!
下一步:
1. 测试应用功能
2. 验证数据完整性
3. 备份SQLite文件后可删除
详细日志已保存到: {}
""".format(log_filename)
        print(success_msg)
        logger.info(success_msg)
    except Exception as e:
        error_msg = f"\n❌ 迁移失败: {e}\n详细日志已保存到: {log_filename}"
        print(error_msg)
        logger.error("迁移过程出错", exc_info=True)
    finally:
        await migrator.cleanup()
        logger.info(f"🔒 数据库连接已关闭,日志文件: {log_filename}")


if __name__ == "__main__":
    asyncio.run(main())
-224
View File
@@ -1,224 +0,0 @@
"""
用户数据迁移脚本 - 从JSON文件迁移到数据库
"""
import asyncio
import json
import os
import sys
from pathlib import Path
# 添加项目根目录到 Python 路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from app.user_manager import user_manager
from app.user_password import password_manager
from app.config import DATA_DIR
async def migrate_users():
    """
    Migrate users from ``DATA_DIR / "users.json"`` through ``user_manager``.

    Every entry is passed to ``user_manager.create_or_update_from_linuxdo``;
    entries flagged ``is_admin`` additionally get ``user_manager.set_admin``.
    A failure on one user is printed and does not stop the others. After the
    loop the JSON file is renamed to ``users.json.backup``.

    Returns:
        int: Number of users migrated; 0 when the file is missing or empty,
        or when the migration as a whole fails.
    """
    source_path = DATA_DIR / "users.json"
    if not source_path.exists():
        print("❌ 用户数据文件不存在,跳过迁移")
        return 0

    try:
        with open(source_path, "r", encoding="utf-8") as handle:
            users_data = json.load(handle)

        if not users_data:
            print("️ 用户数据为空,跳过迁移")
            return 0

        done = 0
        for user_id, user_info in users_data.items():
            try:
                # Basic profile fields; the first three keys are mandatory.
                profile = dict(
                    linuxdo_id=user_info["linuxdo_id"],
                    username=user_info["username"],
                    display_name=user_info["display_name"],
                    avatar_url=user_info.get("avatar_url"),
                    trust_level=user_info.get("trust_level", 0),
                )
                await user_manager.create_or_update_from_linuxdo(**profile)

                # Carry over the admin flag when the entry has one.
                wants_admin = user_info.get("is_admin", False)
                if wants_admin:
                    await user_manager.set_admin(user_id, True)

                done += 1
                print(f"✅ 迁移用户: {user_info['username']} ({user_id})")
            except Exception as e:
                print(f"❌ 迁移用户 {user_id} 失败: {e}")

        total = len(users_data)
        print(f"\n✅ 用户数据迁移完成: {done}/{total} 个用户")

        # Move the processed file out of the way, keeping it as a backup.
        backup_file = DATA_DIR / "users.json.backup"
        source_path.rename(backup_file)
        print(f"📦 原文件已备份到: {backup_file}")
        return done
    except Exception as e:
        print(f"❌ 迁移用户数据失败: {e}")
        return 0
async def migrate_passwords():
    """
    Migrate password records from ``DATA_DIR / "user_passwords.json"``.

    The stored hashes are inserted unchanged as ``UserPassword`` rows, one
    session and commit per record. Users that already have a row are skipped.
    A failure on one record is printed and does not stop the others. After
    the loop the JSON file is renamed to ``user_passwords.json.backup``.

    Returns:
        int: Number of records inserted; 0 when the file is missing or empty,
        or when the migration as a whole fails.
    """
    source_path = DATA_DIR / "user_passwords.json"
    if not source_path.exists():
        print("❌ 密码数据文件不存在,跳过迁移")
        return 0

    try:
        with open(source_path, "r", encoding="utf-8") as handle:
            passwords_data = json.load(handle)

        if not passwords_data:
            print("️ 密码数据为空,跳过迁移")
            return 0

        done = 0
        for user_id, pwd_info in passwords_data.items():
            try:
                # Imports stay inside the per-record try so that an import
                # problem is reported per record, like any other failure.
                from app.models.user import UserPassword
                from app.user_password import password_manager as pm

                async with await pm._get_session() as session:
                    from sqlalchemy import select

                    # Skip users that already have a password row.
                    lookup = select(UserPassword).where(UserPassword.user_id == user_id)
                    existing = (await session.execute(lookup)).scalar_one_or_none()
                    if existing:
                        print(f"️ 密码已存在,跳过: {pwd_info['username']} ({user_id})")
                        continue

                    from datetime import datetime

                    # The value in the JSON file is already a hash.
                    session.add(UserPassword(
                        user_id=user_id,
                        username=pwd_info["username"],
                        password_hash=pwd_info["password_hash"],
                        has_custom_password=pwd_info.get("has_custom_password", False),
                        created_at=datetime.now(),
                        updated_at=datetime.now()
                    ))
                    await session.commit()

                    done += 1
                    print(f"✅ 迁移密码: {pwd_info['username']} ({user_id})")
            except Exception as e:
                print(f"❌ 迁移密码 {user_id} 失败: {e}")

        total = len(passwords_data)
        print(f"\n✅ 密码数据迁移完成: {done}/{total} 个密码")

        # Move the processed file out of the way, keeping it as a backup.
        backup_file = DATA_DIR / "user_passwords.json.backup"
        source_path.rename(backup_file)
        print(f"📦 原文件已备份到: {backup_file}")
        return done
    except Exception as e:
        print(f"❌ 迁移密码数据失败: {e}")
        return 0
async def migrate_admins():
    """
    Grant admin rights to every user listed in ``DATA_DIR / "admins.json"``.

    The file is expected to hold an ``"admins"`` list of user ids; each id is
    passed to ``user_manager.set_admin(user_id, True)``. Only calls that
    return a truthy value are counted. After the loop the JSON file is
    renamed to ``admins.json.backup``.

    Returns:
        int: Number of users successfully marked as admin; 0 when the file is
        missing, the list is empty, or the migration as a whole fails.
    """
    source_path = DATA_DIR / "admins.json"
    if not source_path.exists():
        print("❌ 管理员数据文件不存在,跳过迁移")
        return 0

    try:
        with open(source_path, "r", encoding="utf-8") as handle:
            admins_data = json.load(handle)

        admin_list = admins_data.get("admins", [])
        if not admin_list:
            print("️ 管理员列表为空,跳过迁移")
            return 0

        granted = 0
        for user_id in admin_list:
            try:
                if await user_manager.set_admin(user_id, True):
                    granted += 1
                    print(f"✅ 设置管理员: {user_id}")
                else:
                    print(f"⚠️ 用户不存在或已是管理员: {user_id}")
            except Exception as e:
                print(f"❌ 设置管理员 {user_id} 失败: {e}")

        total = len(admin_list)
        print(f"\n✅ 管理员数据迁移完成: {granted}/{total} 个管理员")

        # Move the processed file out of the way, keeping it as a backup.
        backup_file = DATA_DIR / "admins.json.backup"
        source_path.rename(backup_file)
        print(f"📦 原文件已备份到: {backup_file}")
        return granted
    except Exception as e:
        print(f"❌ 迁移管理员数据失败: {e}")
        return 0
async def main():
    """
    Run the three JSON -> database migration steps in order and print a summary.

    The steps are ``migrate_users``, ``migrate_passwords`` and
    ``migrate_admins``; each returns the number of items it migrated.
    """
    heavy_rule = "=" * 60
    light_rule = "-" * 60

    print(heavy_rule)
    print("用户数据迁移工具 - JSON 到数据库")
    print(heavy_rule)
    print()

    # (heading, step) pairs, executed strictly in this order.
    steps = (
        ("📋 步骤 1/3: 迁移用户数据", migrate_users),
        ("📋 步骤 2/3: 迁移密码数据", migrate_passwords),
        ("📋 步骤 3/3: 迁移管理员数据", migrate_admins),
    )
    counts = []
    for heading, step in steps:
        print(heading)
        print(light_rule)
        counts.append(await step())
        print()

    user_count, pwd_count, admin_count = counts

    # Summary
    print(heavy_rule)
    print("迁移完成")
    print(heavy_rule)
    print(f"✅ 用户: {user_count}")
    print(f"✅ 密码: {pwd_count}")
    print(f"✅ 管理员: {admin_count}")
    print()
    print("💡 提示: 原文件已备份为 .backup 后缀")
    print("💡 如需回滚,请删除数据库文件并恢复 .backup 文件")


if __name__ == "__main__":
    asyncio.run(main())
@@ -1,300 +0,0 @@
"""
用户数据迁移脚本 - 从JSON文件迁移到PostgreSQL数据库
使用方法:
python migrate_users_to_postgres.py
python migrate_users_to_postgres.py --db-url postgresql+asyncpg://user:pass@localhost/dbname
"""
import asyncio
import json
import os
import sys
import argparse
from pathlib import Path
from datetime import datetime
# 添加项目根目录到 Python 路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from sqlalchemy import select, text
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from app.config import settings, DATA_DIR
async def create_tables(engine):
    """
    Create the user-related tables, and only those, if they do not exist yet.

    ``Model.metadata.create_all`` would create every table registered on the
    shared metadata, not just the model it is called through. Each table is
    therefore created individually with ``checkfirst=True`` so this script
    does not create unrelated tables as a side effect.

    Args:
        engine: Async SQLAlchemy engine connected to the target database.
    """
    from app.models.user import User, UserPassword

    print("📋 创建数据库表...")
    async with engine.begin() as conn:
        # User first, in case the password table references it.
        for model in (User, UserPassword):
            # run_sync passes the sync connection as the first argument (bind).
            await conn.run_sync(model.__table__.create, checkfirst=True)
    print("✅ 表创建成功")
async def migrate_users(session):
    """
    Insert users from ``DATA_DIR / "users.json"`` using the given session.

    Users whose ``user_id`` already exists are skipped. New rows are added to
    the session one by one and committed together after the loop. A failure
    while preparing one user is printed and does not stop the others; a
    failure of the whole batch rolls the session back.

    Args:
        session: Async SQLAlchemy session bound to the target database.

    Returns:
        int: Number of users added; 0 when the file is missing or empty, or
        when the batch fails.
    """
    from app.models.user import User as UserModel

    source_path = DATA_DIR / "users.json"
    if not source_path.exists():
        print("️ 用户数据文件不存在,跳过迁移")
        return 0

    try:
        with open(source_path, "r", encoding="utf-8") as handle:
            users_data = json.load(handle)

        if not users_data:
            print("️ 用户数据为空,跳过迁移")
            return 0

        added = 0
        for user_id, user_info in users_data.items():
            try:
                # Skip users that are already present.
                lookup = select(UserModel).where(UserModel.user_id == user_id)
                existing = (await session.execute(lookup)).scalar_one_or_none()
                if existing:
                    print(f"️ 用户已存在,跳过: {user_info['username']} ({user_id})")
                    continue

                # Missing timestamps fall back to the current time.
                fields = {
                    "user_id": user_id,
                    "username": user_info["username"],
                    "display_name": user_info["display_name"],
                    "avatar_url": user_info.get("avatar_url"),
                    "trust_level": user_info.get("trust_level", 0),
                    "is_admin": user_info.get("is_admin", False),
                    "linuxdo_id": user_info["linuxdo_id"],
                    "created_at": datetime.fromisoformat(
                        user_info.get("created_at", datetime.now().isoformat())
                    ),
                    "last_login": datetime.fromisoformat(
                        user_info.get("last_login", datetime.now().isoformat())
                    ),
                }
                session.add(UserModel(**fields))
                added += 1
                print(f"✅ 迁移用户: {user_info['username']} ({user_id})")
            except Exception as e:
                print(f"❌ 迁移用户 {user_id} 失败: {e}")

        await session.commit()
        total = len(users_data)
        print(f"\n✅ 用户数据迁移完成: {added}/{total} 个用户")
        return added
    except Exception as e:
        print(f"❌ 迁移用户数据失败: {e}")
        await session.rollback()
        return 0
async def migrate_passwords(session):
    """
    Insert password records from ``DATA_DIR / "user_passwords.json"``.

    The stored hashes are copied unchanged into ``UserPassword`` rows. Users
    that already have a row are skipped. New rows are committed together
    after the loop; a failure of the whole batch rolls the session back.

    Args:
        session: Async SQLAlchemy session bound to the target database.

    Returns:
        int: Number of records added; 0 when the file is missing or empty, or
        when the batch fails.
    """
    from app.models.user import UserPassword

    source_path = DATA_DIR / "user_passwords.json"
    if not source_path.exists():
        print("️ 密码数据文件不存在,跳过迁移")
        return 0

    try:
        with open(source_path, "r", encoding="utf-8") as handle:
            passwords_data = json.load(handle)

        if not passwords_data:
            print("️ 密码数据为空,跳过迁移")
            return 0

        added = 0
        for user_id, pwd_info in passwords_data.items():
            try:
                # Skip users that already have a password row.
                lookup = select(UserPassword).where(UserPassword.user_id == user_id)
                existing = (await session.execute(lookup)).scalar_one_or_none()
                if existing:
                    print(f"️ 密码已存在,跳过: {pwd_info['username']} ({user_id})")
                    continue

                session.add(UserPassword(
                    user_id=user_id,
                    username=pwd_info["username"],
                    password_hash=pwd_info["password_hash"],
                    has_custom_password=pwd_info.get("has_custom_password", False),
                    created_at=datetime.now(),
                    updated_at=datetime.now()
                ))
                added += 1
                print(f"✅ 迁移密码: {pwd_info['username']} ({user_id})")
            except Exception as e:
                print(f"❌ 迁移密码 {user_id} 失败: {e}")

        await session.commit()
        total = len(passwords_data)
        print(f"\n✅ 密码数据迁移完成: {added}/{total} 个密码")
        return added
    except Exception as e:
        print(f"❌ 迁移密码数据失败: {e}")
        await session.rollback()
        return 0
async def backup_json_files():
    """
    Copy each existing JSON data file to a timestamped backup next to it.

    Looks for ``users.json``, ``user_passwords.json`` and ``admins.json`` in
    ``DATA_DIR``; files that do not exist are silently skipped. The originals
    are left in place.
    """
    import shutil

    print("\n📦 备份原始文件...")
    for filename in ("users.json", "user_passwords.json", "admins.json"):
        source = DATA_DIR / filename
        if not source.exists():
            continue

        # The timestamp is taken per file, at the moment it is copied.
        stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        backup = DATA_DIR / f"{filename}.backup.{stamp}"
        shutil.copy2(source, backup)
        print(f"✅ 备份: {filename} -> {backup.name}")
async def main(db_url=None):
    """
    Migrate user and password data from JSON files into PostgreSQL.

    Creates the user tables, migrates users and passwords (each in its own
    session), backs up the JSON files and prints a summary. On any error the
    traceback is printed and the process exits with status 1. The engine is
    disposed on every path.

    Args:
        db_url: Optional database URL; when omitted, ``settings.database_url``
            is used. The URL must contain ``"postgresql"``.
    """
    import re

    print("=" * 70)
    print("用户数据迁移工具 - JSON 到 PostgreSQL")
    print("=" * 70)
    print()

    # Decide which database URL to use.
    target_db_url = db_url if db_url else settings.database_url

    # Refuse to run against anything that is not PostgreSQL.
    if "postgresql" not in target_db_url:
        print("❌ 错误: 未指定 PostgreSQL 数据库")
        if not db_url:
            print(f" 当前配置: {settings.database_url}")
            print(" 请使用 --db-url 参数指定PostgreSQL数据库,或在 .env 中配置 DATABASE_URL")
        else:
            print(f" 提供的URL: {target_db_url}")
        print()
        print("示例:")
        print(" python migrate_users_to_postgres.py --db-url postgresql+asyncpg://user:pass@localhost/dbname")
        return

    # Hide the password but keep scheme, user and host visible. The greedy
    # match runs to the last '@', so a password containing '@' stays hidden.
    display_url = re.sub(r"(://[^:/@]*):.*@", r"\1:****@", target_db_url)
    print(f"📊 目标数据库: {display_url}")
    print()

    engine = None
    try:
        # Create the database engine.
        engine = create_async_engine(
            target_db_url,
            echo=False,
            future=True,
            pool_pre_ping=True,
        )

        # Create the tables.
        await create_tables(engine)
        print()

        # Session factory shared by both migration steps.
        async_session = async_sessionmaker(
            engine,
            class_=AsyncSession,
            expire_on_commit=False
        )

        # Migrate users.
        print("📋 步骤 1/2: 迁移用户数据")
        print("-" * 70)
        async with async_session() as session:
            user_count = await migrate_users(session)
        print()

        # Migrate passwords.
        print("📋 步骤 2/2: 迁移密码数据")
        print("-" * 70)
        async with async_session() as session:
            pwd_count = await migrate_passwords(session)
        print()

        # Back up the original files.
        await backup_json_files()
        print()

        # Summary
        print("=" * 70)
        print("迁移完成")
        print("=" * 70)
        print(f"✅ 用户: {user_count}")
        print(f"✅ 密码: {pwd_count}")
        print()
        print("💡 提示:")
        print(" - 原文件已备份(带时间戳)")
        print(" - 可以安全删除 users.json 和 user_passwords.json")
        print(" - 如需回滚,请从备份文件恢复")
        print()
    except Exception as e:
        print(f"\n❌ 迁移过程出错: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
    finally:
        # Dispose the engine on failure too, including the sys.exit path.
        if engine is not None:
            await engine.dispose()
if __name__ == "__main__":
    # Command-line entry point: an optional --db-url overrides the configured
    # database URL, then the async migration is run to completion.
    usage_examples = """
示例:
# 使用 .env 配置的数据库
python migrate_users_to_postgres.py
# 指定数据库URL
python migrate_users_to_postgres.py --db-url postgresql+asyncpg://user:pass@localhost/dbname
# 使用环境变量
DATABASE_URL=postgresql+asyncpg://user:pass@localhost/db python migrate_users_to_postgres.py
"""
    db_url_help = "PostgreSQL数据库连接URL (格式: postgresql+asyncpg://user:password@host:port/database)"

    parser = argparse.ArgumentParser(
        description="迁移用户数据从JSON到PostgreSQL数据库",
        epilog=usage_examples,
        # Keep the line breaks of the examples in --help output.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--db-url", default=None, type=str, help=db_url_help)
    args = parser.parse_args()

    # Run the migration.
    asyncio.run(main(db_url=args.db_url))
@@ -1,36 +0,0 @@
-- Move writing styles from project scope to user scope:
-- writing_styles.project_id is replaced by writing_styles.user_id.
--
-- All schema and data changes run in one transaction (PostgreSQL DDL is
-- transactional). If any step fails -- for example the new foreign key
-- rejecting a user_id that is missing from users -- everything is rolled
-- back instead of leaving the table half-migrated and the script unable
-- to run again.
BEGIN;

-- Step 1: add the new user_id column
ALTER TABLE writing_styles ADD COLUMN user_id VARCHAR(255);

-- Step 2: backfill user_id from the owning project.
-- Rows with a NULL project_id keep user_id NULL.
UPDATE writing_styles ws
SET user_id = (
    SELECT p.user_id
    FROM projects p
    WHERE p.id = ws.project_id
)
WHERE ws.project_id IS NOT NULL;

-- Step 3: add the foreign key to users
ALTER TABLE writing_styles
    ADD CONSTRAINT fk_writing_styles_user
    FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE;

-- Step 4: drop the old project_id foreign key, if present
ALTER TABLE writing_styles DROP CONSTRAINT IF EXISTS writing_styles_project_id_fkey;

-- Step 5: drop the project_id column
ALTER TABLE writing_styles DROP COLUMN project_id;

-- Step 6: update the column comment
COMMENT ON COLUMN writing_styles.user_id IS '所属用户ID(NULL表示全局预设风格)';

COMMIT;

-- Verify the result: total rows, rows owned by a user, and global presets
SELECT
    COUNT(*) as total_styles,
    COUNT(user_id) as user_styles,
    COUNT(*) FILTER (WHERE user_id IS NULL) as preset_styles
FROM writing_styles;
@@ -1,25 +0,0 @@
-- Migration: Add outline_mode to projects table
-- Description: adds the outline mode column to projects, supporting the
--              one-to-one and one-to-many modes
-- Date: 2025-11-27
--
-- The statements run in one transaction so a failure part-way through
-- (e.g. the index name already being taken) does not leave the column
-- added without its check constraint.
BEGIN;

-- 1. Add the outline_mode column.
--    NOT NULL with a DEFAULT fills every existing row with 'one-to-many'
--    (the refinement mode most existing projects use), so no separate
--    backfill UPDATE is needed: no row can have a NULL outline_mode.
ALTER TABLE projects
    ADD COLUMN outline_mode VARCHAR(20) NOT NULL DEFAULT 'one-to-many';

-- 2. Restrict the column to the two valid values
ALTER TABLE projects
    ADD CONSTRAINT check_outline_mode
    CHECK (outline_mode IN ('one-to-one', 'one-to-many'));

-- 3. Index for queries filtering by mode
CREATE INDEX idx_projects_outline_mode ON projects(outline_mode);

-- 4. Column comment
COMMENT ON COLUMN projects.outline_mode IS '大纲章节模式: one-to-one(传统模式,1大纲→1章节) 或 one-to-many(细化模式,1大纲→N章节)';

COMMIT;

-- Verify the result
-- SELECT id, title, outline_mode FROM projects LIMIT 10;
@@ -1,34 +0,0 @@
-- Create the prompt template table
CREATE TABLE IF NOT EXISTS prompt_templates (
    id VARCHAR(36) PRIMARY KEY,
    user_id VARCHAR(50) NOT NULL,
    template_key VARCHAR(100) NOT NULL,
    template_name VARCHAR(200) NOT NULL,
    template_content TEXT NOT NULL,
    description TEXT,
    category VARCHAR(50),
    parameters TEXT,
    is_active BOOLEAN DEFAULT TRUE,
    is_system_default BOOLEAN DEFAULT FALSE,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    -- One template per key and user. PostgreSQL backs this constraint with a
    -- unique index on (user_id, template_key), which also serves lookups on
    -- that column pair.
    CONSTRAINT uk_user_template UNIQUE (user_id, template_key)
);

-- Indexes
-- No separate index on (user_id, template_key): it would duplicate the
-- index created by uk_user_template and only add write overhead.
CREATE INDEX IF NOT EXISTS idx_user_id ON prompt_templates(user_id);
CREATE INDEX IF NOT EXISTS idx_category ON prompt_templates(category);

-- Table and column comments
COMMENT ON TABLE prompt_templates IS '提示词模板表';
COMMENT ON COLUMN prompt_templates.user_id IS '用户ID';
COMMENT ON COLUMN prompt_templates.template_key IS '模板键名';
COMMENT ON COLUMN prompt_templates.template_name IS '模板显示名称';
COMMENT ON COLUMN prompt_templates.template_content IS '模板内容';
COMMENT ON COLUMN prompt_templates.description IS '模板描述';
COMMENT ON COLUMN prompt_templates.category IS '模板分类';
COMMENT ON COLUMN prompt_templates.parameters IS '模板参数定义(JSON)';
COMMENT ON COLUMN prompt_templates.is_active IS '是否启用';
COMMENT ON COLUMN prompt_templates.is_system_default IS '是否为系统默认模板';
+21 -6
View File
@@ -36,7 +36,8 @@ except ImportError:
from psycopg2 import sql
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
from app.database import init_db
# 注意: 表结构应由 Alembic 管理
from pathlib import Path
# 设置日志
logging.basicConfig(
@@ -277,12 +278,26 @@ class PostgreSQLSetup:
return False
async def initialize_tables(self) -> bool:
"""初始化数据库表结构"""
"""初始化数据库表结构(使用 Alembic"""
try:
logger.info(f"📋 初始化数据库表结构...")
await init_db('system')
logger.info(f"✅ 表结构初始化成功")
return True
import subprocess
logger.info(f"📋 使用 Alembic 初始化数据库表结构...")
# 运行 Alembic 迁移
result = subprocess.run(
["alembic", "upgrade", "head"],
capture_output=True,
text=True,
cwd=Path(__file__).parent.parent
)
if result.returncode == 0:
logger.info(f"✅ 表结构初始化成功")
return True
else:
logger.error(f"❌ Alembic 迁移失败: {result.stderr}")
return False
except Exception as e:
logger.error(f"❌ 初始化表结构失败: {e}")
return False