Files
MuMuAINovel/backend/app/database.py
T

261 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""数据库连接和会话管理 - 支持多用户数据隔离"""
import asyncio
from typing import Dict, Any
from datetime import datetime
from sqlalchemy import select, text
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import declarative_base
from sqlalchemy.pool import StaticPool
from fastapi import Request, HTTPException
from app.config import settings
from app.logger import get_logger
logger = get_logger(__name__)

# Declarative base class
Base = declarative_base()

# Engine cache: one engine per user, keyed by user ID
_engine_cache: Dict[str, Any] = {}

# Lock bookkeeping: protects the engine-creation path
_engine_locks: Dict[str, asyncio.Lock] = {}
_cache_lock = asyncio.Lock()

# Session counters (used to monitor connection leaks)
_session_stats = dict(
    created=0,
    closed=0,
    active=0,
    errors=0,
    generator_exits=0,
    last_check=None,
)
async def get_engine(user_id: str):
    """Return the database engine dedicated to a user, creating it on first use.

    An engine already present in ``_engine_cache`` is returned immediately.
    Otherwise a per-user ``asyncio.Lock`` (itself created under
    ``_cache_lock``) serializes creation, and the cache is checked again
    once that lock is held so only one engine is built per user.

    Args:
        user_id: User ID; it is interpolated into the SQLite file name
            ``data/ai_story_user_{user_id}.db``.

    Returns:
        The async engine cached for this user.
    """
    # Fast path: no locking when the engine already exists.
    try:
        return _engine_cache[user_id]
    except KeyError:
        pass

    # Fetch (or lazily create) the lock that guards this user's engine.
    async with _cache_lock:
        user_lock = _engine_locks.get(user_id)
        if user_lock is None:
            user_lock = _engine_locks[user_id] = asyncio.Lock()

    async with user_lock:
        # Re-check: another coroutine may have built the engine while we waited.
        if user_id in _engine_cache:
            return _engine_cache[user_id]

        engine = create_async_engine(
            f"sqlite+aiosqlite:///data/ai_story_user_{user_id}.db",
            echo=False,
            future=True,
            poolclass=StaticPool,
            pool_pre_ping=True,
            pool_recycle=3600,
            connect_args={"timeout": 30, "check_same_thread": False},
        )

        tuning_pragmas = (
            "PRAGMA journal_mode=WAL",
            "PRAGMA synchronous=NORMAL",
            "PRAGMA cache_size=-64000",
            "PRAGMA temp_store=MEMORY",
            "PRAGMA busy_timeout=5000",
        )
        # Tuning is best-effort: a failure is logged and the engine is still cached.
        try:
            async with engine.begin() as conn:
                for pragma in tuning_pragmas:
                    await conn.execute(text(pragma))
            logger.info(f"✅ 用户 {user_id} 的数据库已优化(WAL模式 + 64MB缓存)")
        except Exception as e:
            logger.warning(f"⚠️ 用户 {user_id} 数据库优化失败: {str(e)}")

        _engine_cache[user_id] = engine
        logger.info(f"为用户 {user_id} 创建数据库引擎")
        return engine
async def get_db(request: Request):
    """Dependency that yields a database session bound to the current user.

    The user ID is read from ``request.state.user_id``; the session is
    created on that user's engine (see ``get_engine``). Any transaction still
    open when the consumer finishes is rolled back, so callers must commit
    explicitly. The session is always closed on exit.

    The ``_session_stats`` counters are updated in a ``finally`` clause, so
    ``active`` is decremented even when the rollback or close itself fails.
    A session whose close failed is still counted in ``closed`` (keeping
    ``created == closed + active``); the failure is recorded in ``errors``.

    Args:
        request: Incoming request; ``request.state.user_id`` must be set.

    Yields:
        An ``AsyncSession`` for the user's database.

    Raises:
        HTTPException: 401 when ``request.state.user_id`` is missing or empty.
    """
    user_id = getattr(request.state, "user_id", None)
    if not user_id:
        raise HTTPException(status_code=401, detail="未登录或用户ID缺失")

    engine = await get_engine(user_id)
    AsyncSessionLocal = async_sessionmaker(
        engine,
        class_=AsyncSession,
        expire_on_commit=False
    )
    session = AsyncSessionLocal()
    session_id = id(session)

    # The stats dict is only mutated, never rebound, so no ``global`` is needed.
    _session_stats["created"] += 1
    _session_stats["active"] += 1
    logger.debug(f"📊 会话创建 [User:{user_id}][ID:{session_id}] - 活跃:{_session_stats['active']}, 总创建:{_session_stats['created']}, 总关闭:{_session_stats['closed']}")

    try:
        yield session
        # Normal completion: discard whatever the consumer left uncommitted.
        if session.in_transaction():
            await session.rollback()
    except GeneratorExit:
        # The generator was closed before running to completion
        # (logged as an SSE disconnect).
        _session_stats["generator_exits"] += 1
        logger.warning(f"⚠️ GeneratorExit [User:{user_id}][ID:{session_id}] - SSE连接断开(总计:{_session_stats['generator_exits']}次)")
        try:
            if session.in_transaction():
                await session.rollback()
                logger.info(f"✅ 事务已回滚 [User:{user_id}][ID:{session_id}]GeneratorExit")
        except Exception as rollback_error:
            _session_stats["errors"] += 1
            logger.error(f"❌ GeneratorExit回滚失败 [User:{user_id}][ID:{session_id}]: {str(rollback_error)}")
    except Exception as e:
        _session_stats["errors"] += 1
        logger.error(f"❌ 会话异常 [User:{user_id}][ID:{session_id}]: {str(e)}")
        try:
            if session.in_transaction():
                await session.rollback()
                logger.info(f"✅ 事务已回滚 [User:{user_id}][ID:{session_id}](异常)")
        except Exception as rollback_error:
            logger.error(f"❌ 异常回滚失败 [User:{user_id}][ID:{session_id}]: {str(rollback_error)}")
        raise
    finally:
        try:
            # Last line of defence: nothing may stay open past this point.
            if session.in_transaction():
                await session.rollback()
                logger.warning(f"⚠️ finally中发现未提交事务 [User:{user_id}][ID:{session_id}],已回滚")
            await session.close()
        except Exception as e:
            _session_stats["errors"] += 1
            logger.error(f"❌ 关闭会话时出错 [User:{user_id}][ID:{session_id}]: {str(e)}", exc_info=True)
            # Best-effort second attempt; only ordinary errors are swallowed so
            # cancellation and interpreter-exit signals still propagate.
            try:
                await session.close()
            except Exception as close_error:
                logger.error(f"❌ 强制关闭会话失败 [User:{user_id}][ID:{session_id}]: {str(close_error)}")
        finally:
            # Always settle the counters; otherwise a failed rollback/close
            # would leave ``active`` permanently inflated.
            _session_stats["closed"] += 1
            _session_stats["active"] -= 1
            _session_stats["last_check"] = datetime.now().isoformat()
            logger.debug(f"📊 会话关闭 [User:{user_id}][ID:{session_id}] - 活跃:{_session_stats['active']}, 总创建:{_session_stats['created']}, 总关闭:{_session_stats['closed']}, 错误:{_session_stats['errors']}")
            if _session_stats["active"] > 100:
                logger.warning(f"🚨 活跃会话数过多: {_session_stats['active']},可能存在连接泄漏!")
            elif _session_stats["active"] < 0:
                logger.error(f"🚨 活跃会话数异常: {_session_stats['active']},统计可能不准确!")
async def _init_relationship_types(user_id: str):
    """Seed the preset relationship types into a user's database.

    Nothing is inserted when the ``RelationshipType`` table already holds at
    least one row. Any failure is logged and re-raised.

    Args:
        user_id: User ID whose database receives the seed rows.
    """
    from app.models.relationship import RelationshipType

    relationship_types = [
        # family
        {"name": "父亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👨"},
        {"name": "母亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👩"},
        {"name": "兄弟", "category": "family", "reverse_name": "兄弟", "intimacy_range": "high", "icon": "👬"},
        {"name": "姐妹", "category": "family", "reverse_name": "姐妹", "intimacy_range": "high", "icon": "👭"},
        {"name": "子女", "category": "family", "reverse_name": "父母", "intimacy_range": "high", "icon": "👶"},
        {"name": "配偶", "category": "family", "reverse_name": "配偶", "intimacy_range": "high", "icon": "💑"},
        {"name": "恋人", "category": "family", "reverse_name": "恋人", "intimacy_range": "high", "icon": "💕"},
        # social
        {"name": "师父", "category": "social", "reverse_name": "徒弟", "intimacy_range": "high", "icon": "🎓"},
        {"name": "徒弟", "category": "social", "reverse_name": "师父", "intimacy_range": "high", "icon": "📚"},
        {"name": "朋友", "category": "social", "reverse_name": "朋友", "intimacy_range": "medium", "icon": "🤝"},
        {"name": "同学", "category": "social", "reverse_name": "同学", "intimacy_range": "medium", "icon": "🎒"},
        {"name": "邻居", "category": "social", "reverse_name": "邻居", "intimacy_range": "low", "icon": "🏘️"},
        {"name": "知己", "category": "social", "reverse_name": "知己", "intimacy_range": "high", "icon": "💙"},
        # professional
        {"name": "上司", "category": "professional", "reverse_name": "下属", "intimacy_range": "low", "icon": "👔"},
        {"name": "下属", "category": "professional", "reverse_name": "上司", "intimacy_range": "low", "icon": "💼"},
        {"name": "同事", "category": "professional", "reverse_name": "同事", "intimacy_range": "medium", "icon": "🤵"},
        {"name": "合作伙伴", "category": "professional", "reverse_name": "合作伙伴", "intimacy_range": "medium", "icon": "🤜🤛"},
        # hostile
        {"name": "敌人", "category": "hostile", "reverse_name": "敌人", "intimacy_range": "low", "icon": "⚔️"},
        {"name": "仇人", "category": "hostile", "reverse_name": "仇人", "intimacy_range": "low", "icon": "💢"},
        {"name": "竞争对手", "category": "hostile", "reverse_name": "竞争对手", "intimacy_range": "low", "icon": "🎯"},
        # NOTE(review): the icon below is empty, unlike every other entry — confirm this is intended.
        {"name": "宿敌", "category": "hostile", "reverse_name": "宿敌", "intimacy_range": "low", "icon": ""},
    ]

    try:
        session_factory = async_sessionmaker(
            await get_engine(user_id),
            class_=AsyncSession,
            expire_on_commit=False,
        )
        async with session_factory() as session:
            # Any existing row means the presets were already inserted.
            rows = await session.execute(select(RelationshipType))
            if rows.scalars().first():
                logger.info(f"用户 {user_id} 的关系类型数据已存在,跳过初始化")
                return

            logger.info(f"开始为用户 {user_id} 插入关系类型数据...")
            session.add_all([RelationshipType(**fields) for fields in relationship_types])
            await session.commit()
            logger.info(f"成功为用户 {user_id} 插入 {len(relationship_types)} 条关系类型数据")
    except Exception as exc:
        logger.error(f"用户 {user_id} 初始化关系类型数据失败: {exc!s}", exc_info=True)
        raise
async def init_db(user_id: str):
    """Initialize a user's database: create all tables, then insert preset data.

    Tables come from ``Base.metadata``; the preset rows are inserted by
    ``_init_relationship_types``. Any failure is logged and re-raised.

    Args:
        user_id: User ID whose database is initialized.
    """
    try:
        logger.info(f"开始初始化用户 {user_id} 的数据库...")

        user_engine = await get_engine(user_id)
        async with user_engine.begin() as connection:
            # create_all is synchronous, so it runs through run_sync.
            await connection.run_sync(Base.metadata.create_all)

        await _init_relationship_types(user_id)
        logger.info(f"用户 {user_id} 的数据库初始化成功")
    except Exception as exc:
        logger.error(f"用户 {user_id} 的数据库初始化失败: {exc!s}", exc_info=True)
        raise
async def close_db():
    """Dispose every cached engine and empty the engine cache.

    Engines are removed from ``_engine_cache`` one at a time over a snapshot
    of its keys, so a coroutine that adds an engine while a ``dispose()`` is
    awaited cannot break the iteration. A failing ``dispose()`` is logged and
    the remaining engines are still disposed; the first failure is re-raised
    once every engine has been attempted.

    Raises:
        Exception: The first error raised by any ``engine.dispose()`` call.
    """
    logger.info("正在关闭所有数据库连接...")
    first_error = None

    # Iterate a snapshot: awaiting inside a loop over the live dict would
    # raise RuntimeError if the dict changed size in the meantime.
    for user_id in list(_engine_cache):
        engine = _engine_cache.pop(user_id, None)
        if engine is None:
            continue
        try:
            await engine.dispose()
            logger.info(f"用户 {user_id} 的数据库连接已关闭")
        except Exception as e:
            logger.error(f"关闭数据库连接失败: {str(e)}", exc_info=True)
            if first_error is None:
                first_error = e

    if first_error is not None:
        raise first_error
    logger.info("所有数据库连接已关闭")