2025-10-30 11:14:43 +08:00
|
|
|
|
"""数据库连接和会话管理 - 支持多用户数据隔离"""
|
|
|
|
|
|
import asyncio
|
|
|
|
|
|
from typing import Dict, Any
|
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
|
from sqlalchemy import select, text
|
|
|
|
|
|
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
|
|
|
|
|
from sqlalchemy.orm import declarative_base
|
|
|
|
|
|
from sqlalchemy.pool import StaticPool
|
|
|
|
|
|
from fastapi import Request, HTTPException
|
|
|
|
|
|
from app.config import settings
|
|
|
|
|
|
from app.logger import get_logger
|
|
|
|
|
|
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
# 创建基类
|
|
|
|
|
|
Base = declarative_base()
|
|
|
|
|
|
|
2025-11-04 14:38:59 +08:00
|
|
|
|
# 导入所有模型,确保 Base.metadata 能够发现它们
|
|
|
|
|
|
# 这必须在 Base 创建之后、init_db 之前导入
|
|
|
|
|
|
from app.models import (
|
|
|
|
|
|
Project, Outline, Character, Chapter, GenerationHistory,
|
|
|
|
|
|
Settings, WritingStyle, ProjectDefaultStyle,
|
|
|
|
|
|
RelationshipType, CharacterRelationship, Organization, OrganizationMember,
|
|
|
|
|
|
StoryMemory, PlotAnalysis, AnalysisTask
|
|
|
|
|
|
)
|
|
|
|
|
|
|
2025-10-30 11:14:43 +08:00
|
|
|
|
# 引擎缓存:每个用户一个引擎
|
|
|
|
|
|
_engine_cache: Dict[str, Any] = {}
|
|
|
|
|
|
|
|
|
|
|
|
# 锁管理:用于保护引擎创建过程
|
|
|
|
|
|
_engine_locks: Dict[str, asyncio.Lock] = {}
|
|
|
|
|
|
_cache_lock = asyncio.Lock()
|
|
|
|
|
|
|
|
|
|
|
|
# 会话统计(用于监控连接泄漏)
|
|
|
|
|
|
_session_stats = {
|
|
|
|
|
|
"created": 0,
|
|
|
|
|
|
"closed": 0,
|
|
|
|
|
|
"active": 0,
|
|
|
|
|
|
"errors": 0,
|
|
|
|
|
|
"generator_exits": 0,
|
|
|
|
|
|
"last_check": None
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_engine(user_id: str):
|
|
|
|
|
|
"""获取或创建用户专属的数据库引擎(线程安全)
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
user_id: 用户ID
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
用户专属的异步引擎
|
|
|
|
|
|
"""
|
|
|
|
|
|
if user_id in _engine_cache:
|
|
|
|
|
|
return _engine_cache[user_id]
|
|
|
|
|
|
|
|
|
|
|
|
async with _cache_lock:
|
|
|
|
|
|
if user_id not in _engine_locks:
|
|
|
|
|
|
_engine_locks[user_id] = asyncio.Lock()
|
|
|
|
|
|
user_lock = _engine_locks[user_id]
|
|
|
|
|
|
|
|
|
|
|
|
async with user_lock:
|
|
|
|
|
|
if user_id not in _engine_cache:
|
|
|
|
|
|
db_url = f"sqlite+aiosqlite:///data/ai_story_user_{user_id}.db"
|
|
|
|
|
|
engine = create_async_engine(
|
|
|
|
|
|
db_url,
|
|
|
|
|
|
echo=False,
|
|
|
|
|
|
future=True,
|
|
|
|
|
|
poolclass=StaticPool,
|
|
|
|
|
|
pool_pre_ping=True,
|
|
|
|
|
|
pool_recycle=3600,
|
|
|
|
|
|
connect_args={
|
|
|
|
|
|
"timeout": 30,
|
|
|
|
|
|
"check_same_thread": False
|
|
|
|
|
|
}
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
async with engine.begin() as conn:
|
|
|
|
|
|
await conn.execute(text("PRAGMA journal_mode=WAL"))
|
|
|
|
|
|
await conn.execute(text("PRAGMA synchronous=NORMAL"))
|
|
|
|
|
|
await conn.execute(text("PRAGMA cache_size=-64000"))
|
|
|
|
|
|
await conn.execute(text("PRAGMA temp_store=MEMORY"))
|
|
|
|
|
|
await conn.execute(text("PRAGMA busy_timeout=5000"))
|
|
|
|
|
|
|
|
|
|
|
|
logger.info(f"✅ 用户 {user_id} 的数据库已优化(WAL模式 + 64MB缓存)")
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
logger.warning(f"⚠️ 用户 {user_id} 数据库优化失败: {str(e)}")
|
|
|
|
|
|
_engine_cache[user_id] = engine
|
|
|
|
|
|
logger.info(f"为用户 {user_id} 创建数据库引擎")
|
|
|
|
|
|
|
|
|
|
|
|
return _engine_cache[user_id]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_db(request: Request):
|
|
|
|
|
|
"""获取数据库会话的依赖函数
|
|
|
|
|
|
|
|
|
|
|
|
从 request.state.user_id 获取用户ID,然后返回该用户的数据库会话
|
|
|
|
|
|
"""
|
|
|
|
|
|
user_id = getattr(request.state, "user_id", None)
|
|
|
|
|
|
|
|
|
|
|
|
if not user_id:
|
|
|
|
|
|
raise HTTPException(status_code=401, detail="未登录或用户ID缺失")
|
|
|
|
|
|
|
|
|
|
|
|
engine = await get_engine(user_id)
|
|
|
|
|
|
|
|
|
|
|
|
AsyncSessionLocal = async_sessionmaker(
|
|
|
|
|
|
engine,
|
|
|
|
|
|
class_=AsyncSession,
|
|
|
|
|
|
expire_on_commit=False
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
session = AsyncSessionLocal()
|
|
|
|
|
|
session_id = id(session)
|
|
|
|
|
|
|
|
|
|
|
|
global _session_stats
|
|
|
|
|
|
_session_stats["created"] += 1
|
|
|
|
|
|
_session_stats["active"] += 1
|
|
|
|
|
|
|
|
|
|
|
|
logger.debug(f"📊 会话创建 [User:{user_id}][ID:{session_id}] - 活跃:{_session_stats['active']}, 总创建:{_session_stats['created']}, 总关闭:{_session_stats['closed']}")
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
yield session
|
|
|
|
|
|
if session.in_transaction():
|
|
|
|
|
|
await session.rollback()
|
|
|
|
|
|
except GeneratorExit:
|
|
|
|
|
|
_session_stats["generator_exits"] += 1
|
|
|
|
|
|
logger.warning(f"⚠️ GeneratorExit [User:{user_id}][ID:{session_id}] - SSE连接断开(总计:{_session_stats['generator_exits']}次)")
|
|
|
|
|
|
try:
|
|
|
|
|
|
if session.in_transaction():
|
|
|
|
|
|
await session.rollback()
|
|
|
|
|
|
logger.info(f"✅ 事务已回滚 [User:{user_id}][ID:{session_id}](GeneratorExit)")
|
|
|
|
|
|
except Exception as rollback_error:
|
|
|
|
|
|
_session_stats["errors"] += 1
|
|
|
|
|
|
logger.error(f"❌ GeneratorExit回滚失败 [User:{user_id}][ID:{session_id}]: {str(rollback_error)}")
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
_session_stats["errors"] += 1
|
|
|
|
|
|
logger.error(f"❌ 会话异常 [User:{user_id}][ID:{session_id}]: {str(e)}")
|
|
|
|
|
|
try:
|
|
|
|
|
|
if session.in_transaction():
|
|
|
|
|
|
await session.rollback()
|
|
|
|
|
|
logger.info(f"✅ 事务已回滚 [User:{user_id}][ID:{session_id}](异常)")
|
|
|
|
|
|
except Exception as rollback_error:
|
|
|
|
|
|
logger.error(f"❌ 异常回滚失败 [User:{user_id}][ID:{session_id}]: {str(rollback_error)}")
|
|
|
|
|
|
raise
|
|
|
|
|
|
finally:
|
|
|
|
|
|
try:
|
|
|
|
|
|
if session.in_transaction():
|
|
|
|
|
|
await session.rollback()
|
|
|
|
|
|
logger.warning(f"⚠️ finally中发现未提交事务 [User:{user_id}][ID:{session_id}],已回滚")
|
|
|
|
|
|
|
|
|
|
|
|
await session.close()
|
|
|
|
|
|
|
|
|
|
|
|
_session_stats["closed"] += 1
|
|
|
|
|
|
_session_stats["active"] -= 1
|
|
|
|
|
|
_session_stats["last_check"] = datetime.now().isoformat()
|
|
|
|
|
|
|
|
|
|
|
|
logger.debug(f"📊 会话关闭 [User:{user_id}][ID:{session_id}] - 活跃:{_session_stats['active']}, 总创建:{_session_stats['created']}, 总关闭:{_session_stats['closed']}, 错误:{_session_stats['errors']}")
|
|
|
|
|
|
|
2025-10-30 22:01:10 +08:00
|
|
|
|
if _session_stats["active"] > 100:
|
2025-10-30 11:14:43 +08:00
|
|
|
|
logger.warning(f"🚨 活跃会话数过多: {_session_stats['active']},可能存在连接泄漏!")
|
|
|
|
|
|
elif _session_stats["active"] < 0:
|
|
|
|
|
|
logger.error(f"🚨 活跃会话数异常: {_session_stats['active']},统计可能不准确!")
|
|
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
_session_stats["errors"] += 1
|
|
|
|
|
|
logger.error(f"❌ 关闭会话时出错 [User:{user_id}][ID:{session_id}]: {str(e)}", exc_info=True)
|
|
|
|
|
|
try:
|
|
|
|
|
|
await session.close()
|
|
|
|
|
|
except:
|
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
async def _init_relationship_types(user_id: str):
|
|
|
|
|
|
"""为指定用户初始化预置的关系类型数据
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
user_id: 用户ID
|
|
|
|
|
|
"""
|
|
|
|
|
|
from app.models.relationship import RelationshipType
|
|
|
|
|
|
|
|
|
|
|
|
relationship_types = [
|
|
|
|
|
|
{"name": "父亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👨"},
|
|
|
|
|
|
{"name": "母亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👩"},
|
|
|
|
|
|
{"name": "兄弟", "category": "family", "reverse_name": "兄弟", "intimacy_range": "high", "icon": "👬"},
|
|
|
|
|
|
{"name": "姐妹", "category": "family", "reverse_name": "姐妹", "intimacy_range": "high", "icon": "👭"},
|
|
|
|
|
|
{"name": "子女", "category": "family", "reverse_name": "父母", "intimacy_range": "high", "icon": "👶"},
|
|
|
|
|
|
{"name": "配偶", "category": "family", "reverse_name": "配偶", "intimacy_range": "high", "icon": "💑"},
|
|
|
|
|
|
{"name": "恋人", "category": "family", "reverse_name": "恋人", "intimacy_range": "high", "icon": "💕"},
|
|
|
|
|
|
|
|
|
|
|
|
{"name": "师父", "category": "social", "reverse_name": "徒弟", "intimacy_range": "high", "icon": "🎓"},
|
|
|
|
|
|
{"name": "徒弟", "category": "social", "reverse_name": "师父", "intimacy_range": "high", "icon": "📚"},
|
|
|
|
|
|
{"name": "朋友", "category": "social", "reverse_name": "朋友", "intimacy_range": "medium", "icon": "🤝"},
|
|
|
|
|
|
{"name": "同学", "category": "social", "reverse_name": "同学", "intimacy_range": "medium", "icon": "🎒"},
|
|
|
|
|
|
{"name": "邻居", "category": "social", "reverse_name": "邻居", "intimacy_range": "low", "icon": "🏘️"},
|
|
|
|
|
|
{"name": "知己", "category": "social", "reverse_name": "知己", "intimacy_range": "high", "icon": "💙"},
|
|
|
|
|
|
|
|
|
|
|
|
{"name": "上司", "category": "professional", "reverse_name": "下属", "intimacy_range": "low", "icon": "👔"},
|
|
|
|
|
|
{"name": "下属", "category": "professional", "reverse_name": "上司", "intimacy_range": "low", "icon": "💼"},
|
|
|
|
|
|
{"name": "同事", "category": "professional", "reverse_name": "同事", "intimacy_range": "medium", "icon": "🤵"},
|
|
|
|
|
|
{"name": "合作伙伴", "category": "professional", "reverse_name": "合作伙伴", "intimacy_range": "medium", "icon": "🤜🤛"},
|
|
|
|
|
|
|
|
|
|
|
|
{"name": "敌人", "category": "hostile", "reverse_name": "敌人", "intimacy_range": "low", "icon": "⚔️"},
|
|
|
|
|
|
{"name": "仇人", "category": "hostile", "reverse_name": "仇人", "intimacy_range": "low", "icon": "💢"},
|
|
|
|
|
|
{"name": "竞争对手", "category": "hostile", "reverse_name": "竞争对手", "intimacy_range": "low", "icon": "🎯"},
|
|
|
|
|
|
{"name": "宿敌", "category": "hostile", "reverse_name": "宿敌", "intimacy_range": "low", "icon": "⚡"},
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
engine = await get_engine(user_id)
|
|
|
|
|
|
AsyncSessionLocal = async_sessionmaker(
|
|
|
|
|
|
engine,
|
|
|
|
|
|
class_=AsyncSession,
|
|
|
|
|
|
expire_on_commit=False
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
async with AsyncSessionLocal() as session:
|
|
|
|
|
|
result = await session.execute(select(RelationshipType))
|
|
|
|
|
|
existing = result.scalars().first()
|
|
|
|
|
|
|
|
|
|
|
|
if existing:
|
|
|
|
|
|
logger.info(f"用户 {user_id} 的关系类型数据已存在,跳过初始化")
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
logger.info(f"开始为用户 {user_id} 插入关系类型数据...")
|
|
|
|
|
|
for rt_data in relationship_types:
|
|
|
|
|
|
relationship_type = RelationshipType(**rt_data)
|
|
|
|
|
|
session.add(relationship_type)
|
|
|
|
|
|
|
|
|
|
|
|
await session.commit()
|
|
|
|
|
|
logger.info(f"成功为用户 {user_id} 插入 {len(relationship_types)} 条关系类型数据")
|
|
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
logger.error(f"用户 {user_id} 初始化关系类型数据失败: {str(e)}", exc_info=True)
|
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2025-10-31 17:23:25 +08:00
|
|
|
|
async def _init_global_writing_styles(user_id: str):
|
|
|
|
|
|
"""为指定用户初始化全局预设写作风格
|
|
|
|
|
|
|
|
|
|
|
|
全局预设风格的 project_id 为 NULL,所有用户共享
|
|
|
|
|
|
只在第一次创建数据库时插入一次
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
user_id: 用户ID
|
|
|
|
|
|
"""
|
|
|
|
|
|
from app.models.writing_style import WritingStyle
|
|
|
|
|
|
from app.services.prompt_service import WritingStyleManager
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
engine = await get_engine(user_id)
|
|
|
|
|
|
AsyncSessionLocal = async_sessionmaker(
|
|
|
|
|
|
engine,
|
|
|
|
|
|
class_=AsyncSession,
|
|
|
|
|
|
expire_on_commit=False
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
async with AsyncSessionLocal() as session:
|
|
|
|
|
|
# 检查是否已存在全局预设风格
|
|
|
|
|
|
result = await session.execute(
|
|
|
|
|
|
select(WritingStyle).where(WritingStyle.project_id.is_(None))
|
|
|
|
|
|
)
|
|
|
|
|
|
existing = result.scalars().first()
|
|
|
|
|
|
|
|
|
|
|
|
if existing:
|
|
|
|
|
|
logger.info(f"用户 {user_id} 的全局预设风格已存在,跳过初始化")
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
logger.info(f"开始为用户 {user_id} 插入全局预设写作风格...")
|
|
|
|
|
|
|
|
|
|
|
|
# 获取所有预设风格配置
|
|
|
|
|
|
presets = WritingStyleManager.get_all_presets()
|
|
|
|
|
|
|
|
|
|
|
|
for index, (preset_id, preset_data) in enumerate(presets.items(), start=1):
|
|
|
|
|
|
style = WritingStyle(
|
|
|
|
|
|
project_id=None, # NULL 表示全局预设
|
|
|
|
|
|
name=preset_data["name"],
|
|
|
|
|
|
style_type="preset",
|
|
|
|
|
|
preset_id=preset_id,
|
|
|
|
|
|
description=preset_data["description"],
|
|
|
|
|
|
prompt_content=preset_data["prompt_content"],
|
|
|
|
|
|
order_index=index
|
|
|
|
|
|
)
|
|
|
|
|
|
session.add(style)
|
|
|
|
|
|
|
|
|
|
|
|
await session.commit()
|
|
|
|
|
|
logger.info(f"成功为用户 {user_id} 插入 {len(presets)} 个全局预设写作风格")
|
|
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
logger.error(f"用户 {user_id} 初始化全局预设写作风格失败: {str(e)}", exc_info=True)
|
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
2025-10-30 11:14:43 +08:00
|
|
|
|
async def init_db(user_id: str):
|
|
|
|
|
|
"""初始化指定用户的数据库,创建所有表并插入预置数据
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
user_id: 用户ID
|
|
|
|
|
|
"""
|
|
|
|
|
|
try:
|
|
|
|
|
|
logger.info(f"开始初始化用户 {user_id} 的数据库...")
|
|
|
|
|
|
engine = await get_engine(user_id)
|
|
|
|
|
|
|
|
|
|
|
|
async with engine.begin() as conn:
|
|
|
|
|
|
await conn.run_sync(Base.metadata.create_all)
|
|
|
|
|
|
|
|
|
|
|
|
await _init_relationship_types(user_id)
|
2025-10-31 17:23:25 +08:00
|
|
|
|
await _init_global_writing_styles(user_id)
|
2025-10-30 11:14:43 +08:00
|
|
|
|
|
|
|
|
|
|
logger.info(f"用户 {user_id} 的数据库初始化成功")
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
logger.error(f"用户 {user_id} 的数据库初始化失败: {str(e)}", exc_info=True)
|
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def close_db():
|
|
|
|
|
|
"""关闭所有数据库连接"""
|
|
|
|
|
|
try:
|
|
|
|
|
|
logger.info("正在关闭所有数据库连接...")
|
|
|
|
|
|
for user_id, engine in _engine_cache.items():
|
|
|
|
|
|
await engine.dispose()
|
|
|
|
|
|
logger.info(f"用户 {user_id} 的数据库连接已关闭")
|
|
|
|
|
|
_engine_cache.clear()
|
|
|
|
|
|
logger.info("所有数据库连接已关闭")
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
logger.error(f"关闭数据库连接失败: {str(e)}", exc_info=True)
|
|
|
|
|
|
raise
|