update:1.重构项目数据库初始化和迁移逻辑,使用Alembic数据库管理工具

This commit is contained in:
xiamuceer
2025-12-26 15:05:48 +08:00
parent a5788e75ae
commit f32e51b594
39 changed files with 2249 additions and 2037 deletions
+2 -7
View File
@@ -8,7 +8,7 @@ from datetime import datetime
import hashlib
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db, init_db
from app.database import get_db
from app.models.user import User
from app.user_manager import user_manager
from app.user_password import password_manager
@@ -160,12 +160,7 @@ async def create_user(
password=data.password
)
# 初始化用户数据库
try:
await init_db(new_user.user_id)
logger.info(f"用户 {new_user.user_id} 数据库初始化成功")
except Exception as e:
logger.error(f"用户 {new_user.user_id} 数据库初始化失败: {e}")
# Settings 将在首次访问设置页面时自动创建(延迟初始化)
logger.info(f"管理员 {admin.user_id} 创建了新用户 {new_user.user_id} ({data.username})")
+3 -20
View File
@@ -10,7 +10,6 @@ from datetime import datetime, timedelta, timezone
from app.services.oauth_service import LinuxDOOAuthService
from app.user_manager import user_manager
from app.user_password import password_manager
from app.database import init_db
from app.logger import get_logger
from app.config import settings
@@ -152,12 +151,7 @@ async def local_login(request: LocalLoginRequest, response: Response):
logger.info(f"[本地登录] 管理员用户 {user.user_id} 登录成功")
# 初始化用户数据库
try:
await init_db(user.user_id)
logger.info(f"本地用户 {user.user_id} 数据库初始化成功")
except Exception as e:
logger.error(f"本地用户 {user.user_id} 数据库初始化失败: {e}")
# Settings 将在首次访问设置页面时自动创建(延迟初始化)
# 设置 Cookie2小时有效)
max_age = settings.SESSION_EXPIRE_MINUTES * 60
@@ -261,13 +255,7 @@ async def _handle_callback(
default_password = await password_manager.set_password(user.user_id, username)
logger.info(f"用户 {user.user_id} ({username}) 自动绑定默认密码: {default_password}")
# 3.5. 初始化用户数据库(如果是新用户
try:
await init_db(user.user_id)
logger.info(f"用户 {user.user_id} 数据库初始化成功")
except Exception as e:
logger.error(f"用户 {user.user_id} 数据库初始化失败: {e}")
# 继续执行,不影响登录流程(可能是已存在的用户)
# Settings 将在首次访问设置页面时自动创建(延迟初始化
# 4. 设置 Cookie 并重定向到前端回调页面
# 使用配置的前端URL,支持不同的部署环境
@@ -495,12 +483,7 @@ async def bind_account_login(request: LocalLoginRequest, response: Response):
if not is_valid:
raise HTTPException(status_code=401, detail="用户名或密码错误")
# 初始化用户数据库
try:
await init_db(target_user.user_id)
logger.info(f"绑定账号用户 {target_user.user_id} 数据库初始化成功")
except Exception as e:
logger.error(f"绑定账号用户 {target_user.user_id} 数据库初始化失败: {e}")
# Settings 将在首次访问设置页面时自动创建(延迟初始化)
# 设置 Cookie2小时有效)
max_age = settings.SESSION_EXPIRE_MINUTES * 60
+17 -139
View File
@@ -21,7 +21,7 @@ from app.models import (
Settings, WritingStyle, ProjectDefaultStyle,
RelationshipType, CharacterRelationship, Organization, OrganizationMember,
StoryMemory, PlotAnalysis, AnalysisTask, BatchGenerationTask,
RegenerationTask, Career, CharacterCareer
RegenerationTask, Career, CharacterCareer, User, MCPPlugin, PromptTemplate
)
# 引擎缓存:每个用户一个引擎
@@ -223,147 +223,25 @@ async def get_db(request: Request):
except:
pass
async def _init_relationship_types(user_id: str):
"""为指定用户初始化预置的关系类型数据
async def init_db(user_id: str = None):
"""
初始化数据库(已弃用)
⚠️ 此函数已弃用,仅保留用于向后兼容
新的最佳实践:
- 表结构管理: 使用 'alembic upgrade head'
- 用户配置: Settings 在首次访问时自动创建(延迟初始化)
Args:
user_id: 用户ID
user_id: 用户ID (已不再使用)
"""
from app.models.relationship import RelationshipType
relationship_types = [
{"name": "父亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👨"},
{"name": "母亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👩"},
{"name": "兄弟", "category": "family", "reverse_name": "兄弟", "intimacy_range": "high", "icon": "👬"},
{"name": "姐妹", "category": "family", "reverse_name": "姐妹", "intimacy_range": "high", "icon": "👭"},
{"name": "子女", "category": "family", "reverse_name": "父母", "intimacy_range": "high", "icon": "👶"},
{"name": "配偶", "category": "family", "reverse_name": "配偶", "intimacy_range": "high", "icon": "💑"},
{"name": "恋人", "category": "family", "reverse_name": "恋人", "intimacy_range": "high", "icon": "💕"},
{"name": "师父", "category": "social", "reverse_name": "徒弟", "intimacy_range": "high", "icon": "🎓"},
{"name": "徒弟", "category": "social", "reverse_name": "师父", "intimacy_range": "high", "icon": "📚"},
{"name": "朋友", "category": "social", "reverse_name": "朋友", "intimacy_range": "medium", "icon": "🤝"},
{"name": "同学", "category": "social", "reverse_name": "同学", "intimacy_range": "medium", "icon": "🎒"},
{"name": "邻居", "category": "social", "reverse_name": "邻居", "intimacy_range": "low", "icon": "🏘️"},
{"name": "知己", "category": "social", "reverse_name": "知己", "intimacy_range": "high", "icon": "💙"},
{"name": "上司", "category": "professional", "reverse_name": "下属", "intimacy_range": "low", "icon": "👔"},
{"name": "下属", "category": "professional", "reverse_name": "上司", "intimacy_range": "low", "icon": "💼"},
{"name": "同事", "category": "professional", "reverse_name": "同事", "intimacy_range": "medium", "icon": "🤵"},
{"name": "合作伙伴", "category": "professional", "reverse_name": "合作伙伴", "intimacy_range": "medium", "icon": "🤜🤛"},
{"name": "敌人", "category": "hostile", "reverse_name": "敌人", "intimacy_range": "low", "icon": "⚔️"},
{"name": "仇人", "category": "hostile", "reverse_name": "仇人", "intimacy_range": "low", "icon": "💢"},
{"name": "竞争对手", "category": "hostile", "reverse_name": "竞争对手", "intimacy_range": "low", "icon": "🎯"},
{"name": "宿敌", "category": "hostile", "reverse_name": "宿敌", "intimacy_range": "low", "icon": ""},
]
try:
engine = await get_engine(user_id)
AsyncSessionLocal = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False
)
async with AsyncSessionLocal() as session:
result = await session.execute(select(RelationshipType))
existing = result.scalars().first()
if existing:
logger.info(f"用户 {user_id} 的关系类型数据已存在,跳过初始化")
return
logger.info(f"开始为用户 {user_id} 插入关系类型数据...")
for rt_data in relationship_types:
relationship_type = RelationshipType(**rt_data)
session.add(relationship_type)
await session.commit()
logger.info(f"成功为用户 {user_id} 插入 {len(relationship_types)} 条关系类型数据")
except Exception as e:
logger.error(f"用户 {user_id} 初始化关系类型数据失败: {str(e)}", exc_info=True)
raise
async def _init_global_writing_styles(user_id: str):
"""为指定用户初始化全局预设写作风格
全局预设风格的 project_id 为 NULL,所有用户共享
只在第一次创建数据库时插入一次
Args:
user_id: 用户ID
"""
from app.models.writing_style import WritingStyle
from app.services.prompt_service import WritingStyleManager
try:
engine = await get_engine(user_id)
AsyncSessionLocal = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False
)
async with AsyncSessionLocal() as session:
# 检查是否已存在全局预设风格
result = await session.execute(
select(WritingStyle).where(WritingStyle.user_id.is_(None))
)
existing = result.scalars().first()
if existing:
logger.info(f"用户 {user_id} 的全局预设风格已存在,跳过初始化")
return
logger.info(f"开始为用户 {user_id} 插入全局预设写作风格...")
# 获取所有预设风格配置
presets = WritingStyleManager.get_all_presets()
for index, (preset_id, preset_data) in enumerate(presets.items(), start=1):
style = WritingStyle(
user_id=None, # NULL 表示全局预设
name=preset_data["name"],
style_type="preset",
preset_id=preset_id,
description=preset_data["description"],
prompt_content=preset_data["prompt_content"],
order_index=index
)
session.add(style)
await session.commit()
logger.info(f"成功为用户 {user_id} 插入 {len(presets)} 个全局预设写作风格")
except Exception as e:
logger.error(f"用户 {user_id} 初始化全局预设写作风格失败: {str(e)}", exc_info=True)
raise
async def init_db(user_id: str):
"""初始化指定用户的数据库,创建所有表并插入预置数据
Args:
user_id: 用户ID
"""
try:
logger.info(f"开始初始化用户 {user_id} 的数据库...")
engine = await get_engine(user_id)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
await _init_relationship_types(user_id)
await _init_global_writing_styles(user_id)
logger.info(f"用户 {user_id} 的数据库初始化成功")
except Exception as e:
logger.error(f"用户 {user_id} 的数据库初始化失败: {str(e)}", exc_info=True)
raise
logger.warning(
"⚠️ init_db() 已弃用且无实际作用!\n"
" - 表结构: 由 Alembic 管理\n"
" - 用户配置: Settings API 自动创建\n"
" 建议移除此调用"
)
async def close_db():
+1 -17
View File
@@ -27,23 +27,7 @@ logger = get_logger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期管理"""
logger.info("应用启动,初始化数据库表结构...")
# 在应用启动时初始化数据库表结构
try:
from app.database import get_engine, Base
# 使用全局引擎创建所有表
engine = await get_engine("_global_init_")
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
logger.info("✅ 数据库表结构初始化成功")
except Exception as e:
logger.error(f"❌ 数据库表结构初始化失败: {str(e)}", exc_info=True)
# 不阻止应用启动,允许在后续操作中重试
logger.info("应用启动完成,等待用户登录...")
logger.info("应用启动完成")
yield
+16 -6
View File
@@ -149,13 +149,23 @@ class MCPPluginRegistry:
session.status = "active"
logger.info(f"✅ 会话 {plugin_id} 恢复正常")
# 检查长时间无活动的会话
# 检查即将过期的会话(最后1分钟提醒)
idle_time = time.time() - session.last_access
if idle_time > mcp_config.IDLE_TIMEOUT_SECONDS:
logger.info(
f"💤 会话 {plugin_id} 空闲 {idle_time/60:.1f} 分钟,"
f"准备清理"
)
time_until_expiry = self._client_ttl - idle_time
# 仅在最后1分钟(60秒)内提醒一次
if 0 < time_until_expiry <= 60:
# 使用会话属性避免重复提醒
if not hasattr(session, '_expiry_warned') or not session._expiry_warned:
logger.warning(
f"⏰ 会话 {plugin_id} 即将过期 "
f"(剩余 {time_until_expiry:.0f} 秒)"
)
session._expiry_warned = True
elif time_until_expiry > 60:
# 重置警告标志(如果会话被重新使用)
if hasattr(session, '_expiry_warned'):
session._expiry_warned = False
async def _get_user_lock(self, user_id: str) -> asyncio.Lock:
"""
+3 -1
View File
@@ -15,6 +15,7 @@ from app.models.mcp_plugin import MCPPlugin
from app.models.user import User, UserPassword
from app.models.regeneration_task import RegenerationTask
from app.models.career import Career, CharacterCareer
from app.models.prompt_template import PromptTemplate
__all__ = [
"Project",
@@ -38,5 +39,6 @@ __all__ = [
"UserPassword",
"RegenerationTask",
"Career",
"CharacterCareer"
"CharacterCareer",
"PromptTemplate"
]
+5 -4
View File
@@ -4,7 +4,8 @@ from openai import AsyncOpenAI
from anthropic import AsyncAnthropic
from app.config import settings as app_settings
from app.logger import get_logger
from app.mcp.adapters import UniversalMCPAdapter, PromptInjectionAdapter
from app.mcp.adapters import PromptInjectionAdapter
from app.mcp.adapters.universal import universal_mcp_adapter
import httpx
import json
import hashlib
@@ -145,11 +146,11 @@ class AIService:
self.default_temperature = default_temperature or app_settings.default_temperature
self.default_max_tokens = default_max_tokens or app_settings.default_max_tokens
# 初始化MCP适配器
# 使用全局MCP适配器单例
self.enable_mcp_adapter = enable_mcp_adapter
if enable_mcp_adapter:
self.mcp_adapter = UniversalMCPAdapter()
logger.info("✅ MCP通用适配器已启用")
self.mcp_adapter = universal_mcp_adapter
logger.info("✅ MCP通用适配器已启用(使用全局单例)")
else:
self.mcp_adapter = None
logger.info("⚠️ MCP适配器已禁用")
+43 -12
View File
@@ -18,21 +18,52 @@ from pathlib import Path
if 'SENTENCE_TRANSFORMERS_HOME' not in os.environ:
# 根据运行环境确定模型目录
if getattr(sys, 'frozen', False):
# PyInstaller 打包后
base_dir = Path(sys.executable).parent
# PyInstaller 打包后 - 需要检查多个可能的位置
exe_dir = Path(sys.executable).parent
# 检查顺序:
# 1. _MEIPASS/backend/embedding (临时解压目录)
# 2. exe同级/_internal/backend/embedding
# 3. exe同级/backend/embedding
possible_paths = []
if hasattr(sys, '_MEIPASS'):
possible_paths.append(Path(sys._MEIPASS) / 'backend' / 'embedding')
possible_paths.extend([
exe_dir / '_internal' / 'backend' / 'embedding',
exe_dir / 'backend' / 'embedding',
exe_dir / '_internal' / 'embedding',
exe_dir / 'embedding'
])
model_dir = None
for path in possible_paths:
if path.exists():
model_dir = path
logger.info(f"🔧 找到打包环境模型目录: {model_dir}")
break
if model_dir:
os.environ['SENTENCE_TRANSFORMERS_HOME'] = str(model_dir)
else:
# 最后降级方案
fallback_dir = exe_dir / 'embedding'
os.environ['SENTENCE_TRANSFORMERS_HOME'] = str(fallback_dir)
logger.warning(f"⚠️ 未找到预打包模型,使用降级目录: {fallback_dir}")
logger.warning(f" 检查过的路径: {[str(p) for p in possible_paths]}")
else:
# 开发模式,从当前文件位置向上找到项目根目录
base_dir = Path(__file__).parent.parent.parent
model_dir = base_dir / 'backend' / 'embedding'
if model_dir.exists():
os.environ['SENTENCE_TRANSFORMERS_HOME'] = str(model_dir)
logger.info(f"🔧 设置模型目录: {model_dir}")
else:
# 降级到项目根目录的 embedding
fallback_dir = base_dir / 'embedding'
os.environ['SENTENCE_TRANSFORMERS_HOME'] = str(fallback_dir)
logger.info(f"🔧 使用降级模型目录: {fallback_dir}")
model_dir = base_dir / 'backend' / 'embedding'
if model_dir.exists():
os.environ['SENTENCE_TRANSFORMERS_HOME'] = str(model_dir)
logger.info(f"🔧 设置开发环境模型目录: {model_dir}")
else:
# 降级到项目根目录的 embedding
fallback_dir = base_dir / 'embedding'
os.environ['SENTENCE_TRANSFORMERS_HOME'] = str(fallback_dir)
logger.info(f"🔧 使用降级模型目录: {fallback_dir}")
class MemoryService: