From 20d9319a16431ab68536391d943e5c102ef5246e Mon Sep 17 00:00:00 2001 From: xiamuceer Date: Mon, 10 Nov 2025 21:16:55 +0800 Subject: [PATCH] =?UTF-8?q?update:1.=E5=88=87=E6=8D=A2=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E5=BA=93PostgreSQL?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .dockerignore | 1 + README.md | 124 ++- backend/.env.example | 162 +++- backend/app/api/chapters.py | 142 ++- backend/app/api/characters.py | 71 +- backend/app/api/memories.py | 63 +- backend/app/api/organizations.py | 115 ++- backend/app/api/outlines.py | 104 ++- backend/app/api/projects.py | 196 ++++- backend/app/api/relationships.py | 48 +- backend/app/api/wizard_stream.py | 6 + backend/app/api/writing_styles.py | 80 +- backend/app/config.py | 37 +- backend/app/database.py | 226 ++++- backend/app/main.py | 1 - backend/app/mcp/http_client.py | 17 +- backend/app/models/character.py | 4 +- backend/app/models/memory.py | 4 +- backend/app/models/project.py | 1 + backend/app/models/relationship.py | 2 +- backend/app/services/ai_service.py | 8 +- backend/app/services/memory_service.py | 2 +- .../adapter_config.json | 0 backend/requirements.txt | 12 +- backend/scripts/init_postgres.sql | 30 + backend/scripts/migrate_sqlite_to_postgres.py | 816 ++++++++++++++++++ backend/scripts/setup_postgres.py | 408 +++++++++ docker-compose.yml | 79 +- frontend/src/components/ChapterAnalysis.tsx | 8 + frontend/src/pages/Chapters.tsx | 1 + frontend/src/types/index.ts | 14 +- 31 files changed, 2526 insertions(+), 256 deletions(-) create mode 100644 backend/embedding/models--sentence-transformers--paraphrase-multilingual-MiniLM-L12-v2/.no_exist/86741b4e3f5cb7765a600d3a3d55a0f6a6cb443d/adapter_config.json create mode 100644 backend/scripts/init_postgres.sql create mode 100644 backend/scripts/migrate_sqlite_to_postgres.py create mode 100644 backend/scripts/setup_postgres.py diff --git a/.dockerignore b/.dockerignore index ef6162d..7dfbc79 100644 --- a/.dockerignore +++ b/.dockerignore @@ -39,6 +39,7 @@ Thumbs.db # 数据库文件(不包含在镜像中) data/*.db backend/data/*.db +postgres_data/ # ChromaDB数据库(不包含在镜像中,会在运行时生成) backend/data/chroma_db/ diff --git a/README.md b/README.md index 32fa39b..5699d83 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ - 🌐 **世界观设定** - 构建完整的故事世界观和背景设定 - 🔐 **多种登录方式** - 支持 LinuxDO OAuth 登录和本地账户登录 - 🐳 **Docker 部署** - 一键部署,开箱即用 -- 💾 **数据持久化** - 基于 SQLite 的本地数据存储,支持多用户隔离 +- 💾 **数据持久化** - 支持 PostgreSQL 和 SQLite 双数据库,多用户数据隔离 - 🎨 **现代化 UI** - 基于 Ant Design 的美观界面,响应式设计 @@ -36,7 +36,7 @@ - [ ] **灵感模式** - 提供创作灵感和点子生成功能 - [✔] **自定义写作风格** - 支持自定义AI写作风格和语言风格 -- [ ] **支持数据导入导出** - 支持项目数据的导入和导出功能 +- [✔] **支持数据导入导出** - 支持项目数据的导入和导出功能 - [ ] **添加prompt调整界面** - 提供可视化的prompt模板编辑和调整界面 - [✔] **开放章节内容字数限制** - 支持用户在生成章节内容时设置字数 @wyf007 - [ ] **设定追溯与矛盾检测** - 对大纲、世界观、角色档案中的设定支持悬停查看注释,显示相关章节来源和佐证原文;自动检测新章节与已有设定的矛盾(吃书),标记为"矛盾"设定并提供解决建议,当新设定解决矛盾后自动更新注释说明 @lulujiang @@ -116,9 +116,77 @@ npm run build ## 🐳 部署方式 -### Docker Compose 部署 +### Docker Compose 部署(PostgreSQL) -#### 使用 Docker Hub 镜像(推荐) +**推荐生产环境使用**,提供更好的性能和并发支持。 + +#### 快速启动 + +```bash +# 1. 克隆项目 +git clone https://github.com/xiamuceer-j/MuMuAINovel.git +cd MuMuAINovel + +# 2. 配置环境变量 +cp backend/.env.example .env +# 编辑 .env 文件,设置必要的配置: +# - POSTGRES_PASSWORD(数据库密码) +# - OPENAI_API_KEY(AI服务密钥) +# - 其他可选配置 + +# 3. 启动服务(包含PostgreSQL) +docker-compose up -d + +# 4. 查看服务状态 +docker-compose ps + +# 5. 查看日志 +docker-compose logs -f + +# 6. 访问应用 +# 打开浏览器访问 http://localhost:8000 +``` + +#### 环境变量配置 + +创建 `.env` 文件并配置: + +```bash +# PostgreSQL数据库密码(必须设置) +POSTGRES_PASSWORD=your_secure_password_here + +# AI服务配置(必须设置) +OPENAI_API_KEY=your_openai_api_key +DEFAULT_AI_PROVIDER=openai +DEFAULT_MODEL=gpt-4 + +# 本地账户登录(可选) +LOCAL_AUTH_ENABLED=true +LOCAL_AUTH_USERNAME=admin +LOCAL_AUTH_PASSWORD=admin123 + +# 其他配置见 backend/.env.example +``` + +#### 服务说明 + +- **postgres**: PostgreSQL 18 数据库 + - 端口:5432 + - 数据持久化:`./postgres_data` + - 已优化配置,支持80-150并发用户 + +- **mumuainovel**: 主应用服务 + - 端口:8000 + - 自动等待数据库就绪 + - 日志持久化:`./logs` + +详细部署指南请参考:[Docker + PostgreSQL 部署文档](docs/docker-postgres-deployment.md) + +### Docker Compose 部署(SQLite) + +适合个人使用或小团队,配置更简单。 + +#### 使用 Docker Hub 镜像 项目已发布到 Docker Hub,可直接拉取使用: @@ -197,7 +265,11 @@ networks: #### 2. 数据持久化 -数据目录已通过 volume 挂载,数据不会丢失: +**PostgreSQL部署**: +- `./postgres_data`:PostgreSQL 数据库文件 +- `./logs`:应用日志文件 + +**SQLite部署**: - `./data`:SQLite 数据库文件 - `./logs`:应用日志文件 @@ -238,11 +310,33 @@ services: ## ⚙️ 配置说明 +### 数据库选择 + +项目支持两种数据库: + +```bash +# .env 配置 +DATABASE_URL=postgresql+asyncpg://mumuai:password@postgres:5432/mumuai_novel +``` + +```bash +# .env 配置 +DATABASE_URL=sqlite+aiosqlite:///data/ai_story.db +``` + ### 环境变量 创建 `.env` 文件并配置以下变量: ```bash +# ===== 数据库配置 ===== +# PostgreSQL(生产环境推荐) +DATABASE_URL=postgresql+asyncpg://mumuai:password@postgres:5432/mumuai_novel +POSTGRES_PASSWORD=your_secure_password_here + +# 或使用 SQLite(开发环境) +# DATABASE_URL=sqlite+aiosqlite:///data/ai_story.db + # ===== AI 服务配置(必填)===== # OpenAI 配置(支持官方API和中转API) OPENAI_API_KEY=your_openai_key_here @@ -325,24 +419,6 @@ OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxx OPENAI_BASE_URL=https://api.new-api.com/v1 ``` -**API2D** -```bash -OPENAI_API_KEY=fk-xxxxxxxxxxxxxxxx -OPENAI_BASE_URL=https://api.api2d.com/v1 -``` - -**OpenAI-SB** -```bash -OPENAI_API_KEY=sb-xxxxxxxxxxxxxxxx -OPENAI_BASE_URL=https://api.openai-sb.com/v1 -``` - -**自建 One API / New API** -```bash -OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxx -OPENAI_BASE_URL=https://your-domain.com/v1 -``` - ##### 注意事项 - ✅ 所有支持 OpenAI 接口格式的服务都可以使用 @@ -436,7 +512,7 @@ MuMuAINovel/ ### 后端 - **框架**:FastAPI 0.109.0 -- **数据库**:SQLite + SQLAlchemy(异步) +- **数据库**:PostgreSQL / SQLite + SQLAlchemy(异步) - **AI 集成**:OpenAI、Anthropic、Google Gemini SDK - **认证**:LinuxDO OAuth2、本地账户 - **日志**:Python logging + 文件轮转 diff --git a/backend/.env.example b/backend/.env.example index ddcefe1..3646421 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -1,54 +1,140 @@ -# AI服务配置 -# OpenAI配置 -OPENAI_API_KEY=your_openai_key_here -OPENAI_BASE_URL=https://api.openai.com/v1 - -# Anthropic配置 -ANTHROPIC_API_KEY=your_anthropic_key_here -ANTHROPIC_BASE_URL=https://api.anthropic.com - -# 默认AI提供商:openai, gemini, anthropic -DEFAULT_AI_PROVIDER=openai -DEFAULT_MODEL=gpt-4.1 -DEFAULT_TEMPERATURE=0.8 -DEFAULT_MAX_TOKENS=32000 +# ========================================== +# MuMuAINovel 配置文件示例 +# ========================================== +# 复制此文件为 .env 并修改配置值 +# cp .env.example .env +# ========================================== # 应用配置 +# ========================================== APP_NAME=MuMuAINovel APP_VERSION=1.0.0 APP_HOST=0.0.0.0 APP_PORT=8000 -DEBUG=true +DEBUG=True +# ========================================== +# 数据库配置 +# ========================================== + +# 选项1: PostgreSQL(生产环境推荐) +# DATABASE_URL=postgresql+asyncpg://username:password@localhost:5432/database_name +# 示例: +# DATABASE_URL=postgresql+asyncpg://mumuai:your_password@localhost:5432/mumuai_novel + +# 选项2: SQLite(开发环境,默认) +DATABASE_URL=sqlite+aiosqlite:///data/ai_story.db + +# PostgreSQL 连接池配置(优化后,支持80-150并发用户) +DATABASE_POOL_SIZE=30 # 核心连接数(默认30,小团队可用20) +DATABASE_MAX_OVERFLOW=20 # 最大溢出连接数(默认20,小团队可用10) +DATABASE_POOL_TIMEOUT=60 # 连接等待超时秒数(默认60) +DATABASE_POOL_RECYCLE=1800 # 连接回收时间秒数(默认1800=30分钟) +DATABASE_POOL_PRE_PING=True # 连接前检测是否有效 +DATABASE_POOL_USE_LIFO=True # 使用LIFO策略提高连接复用率 + +# 会话监控配置 +DATABASE_SESSION_MAX_ACTIVE=50 # 活跃会话警告阈值 +DATABASE_SESSION_LEAK_THRESHOLD=100 # 会话泄漏严重告警阈值 + +# SQLite 性能优化配置(仅在使用SQLite时生效) +SQLITE_CACHE_SIZE_MB=128 # SQLite缓存大小(MB),默认128 +SQLITE_MMAP_SIZE_MB=256 # 内存映射I/O大小(MB),默认256 +SQLITE_WAL_AUTOCHECKPOINT=1000 # WAL自动检查点间隔 + +# 数据库监控配置 +DATABASE_ENABLE_SLOW_QUERY_LOG=True # 启用慢查询日志 +DATABASE_SLOW_QUERY_THRESHOLD=1.0 # 慢查询阈值(秒) +DATABASE_ENABLE_METRICS=True # 启用性能指标收集 + +# ========================================== +# 日志配置 +# ========================================== +LOG_LEVEL=INFO +LOG_TO_FILE=True +LOG_FILE_PATH=logs/app.log +LOG_MAX_BYTES=10485760 +LOG_BACKUP_COUNT=30 + +# ========================================== +# CORS配置 +# ========================================== +CORS_ORIGINS=["http://localhost:8000","http://127.0.0.1:8000"] + +# ========================================== +# AI服务配置 +# ========================================== + +# OpenAI配置 +OPENAI_API_KEY=your_openai_api_key_here +OPENAI_BASE_URL=https://api.openai.com/v1 +# 或使用兼容的API服务(如DeepSeek) +# OPENAI_BASE_URL=https://api.deepseek.com/v1 + +# Gemini配置(可选) +# GEMINI_API_KEY=your_gemini_api_key_here +# GEMINI_BASE_URL=https://generativelanguage.googleapis.com + +# Anthropic配置(可选) +# ANTHROPIC_API_KEY=your_anthropic_api_key_here +# ANTHROPIC_BASE_URL=https://api.anthropic.com + +# 默认AI配置 +DEFAULT_AI_PROVIDER=openai +DEFAULT_MODEL=gpt-4 +DEFAULT_TEMPERATURE=0.7 +DEFAULT_MAX_TOKENS=2000 + +# ========================================== # LinuxDO OAuth2 配置(可选) -# 注意:Docker部署时,LINUXDO_REDIRECT_URI 应该使用实际的域名或服务器IP -# 本地开发: http://localhost:8000/api/auth/callback -# 生产环境: https://your-domain.com/api/auth/callback 或 http://your-server-ip:8000/api/auth/callback -LINUXDO_CLIENT_ID=your_client_id_here -LINUXDO_CLIENT_SECRET=your_client_secret_here -LINUXDO_REDIRECT_URI=http://localhost:8000/api/auth/callback +# ========================================== +# LINUXDO_CLIENT_ID=your_client_id +# LINUXDO_CLIENT_SECRET=your_client_secret +# LINUXDO_REDIRECT_URI=http://localhost:8000/api/auth/callback -# 前端URL配置(用于OAuth回调后重定向到前端) -# 本地开发: http://localhost:8000 -# 生产环境: https://your-domain.com 或 http://your-server-ip:8000 +# 前端URL(OAuth回调后重定向) FRONTEND_URL=http://localhost:8000 -# 本地账户登录配置 -# 启用本地账户登录(true/false) -LOCAL_AUTH_ENABLED=true -# 本地登录用户名 -LOCAL_AUTH_USERNAME=admin -# 本地登录密码 -LOCAL_AUTH_PASSWORD=your_secure_password_here -# 本地用户显示名称 -LOCAL_AUTH_DISPLAY_NAME=管理员 +# 初始管理员(LinuxDO user_id) +# INITIAL_ADMIN_LINUXDO_ID=12345 +# ========================================== +# 本地账户登录配置 +# ========================================== +LOCAL_AUTH_ENABLED=True +LOCAL_AUTH_USERNAME=admin +LOCAL_AUTH_PASSWORD=admin123 +LOCAL_AUTH_DISPLAY_NAME=本地管理员 + +# ========================================== # 会话配置 -# 会话过期时间(分钟),默认120分钟(2小时) +# ========================================== SESSION_EXPIRE_MINUTES=120 -# 会话刷新阈值(分钟),剩余时间少于此值时可刷新,默认30分钟 SESSION_REFRESH_THRESHOLD_MINUTES=30 -# CORS配置(生产环境) -# 允许的跨域来源,多个用逗号分隔 -# CORS_ORIGINS=https://your-domain.com,https://www.your-domain.com \ No newline at end of file +# ========================================== +# 部署配置说明 +# ========================================== + +# 生产环境 PostgreSQL 配置示例(50-100并发用户): +# DATABASE_URL=postgresql+asyncpg://mumuai:SecurePassword123@db.example.com:5432/mumuai_prod +# DATABASE_POOL_SIZE=30 +# DATABASE_MAX_OVERFLOW=20 +# DATABASE_POOL_TIMEOUT=60 +# DATABASE_POOL_RECYCLE=1800 +# DATABASE_SESSION_MAX_ACTIVE=50 +# DEBUG=False +# LOG_LEVEL=WARNING + +# 高并发环境 PostgreSQL 配置示例(100+并发用户): +# DATABASE_URL=postgresql+asyncpg://mumuai:SecurePassword123@db.example.com:5432/mumuai_prod +# DATABASE_POOL_SIZE=40 +# DATABASE_MAX_OVERFLOW=30 +# DATABASE_SESSION_MAX_ACTIVE=80 +# DATABASE_SESSION_LEAK_THRESHOLD=150 + +# Docker 部署配置示例: +# DATABASE_URL=postgresql+asyncpg://mumuai:password@postgres:5432/mumuai_novel +# OPENAI_BASE_URL=https://api.openai.com/v1 +# FRONTEND_URL=https://your-domain.com +# LINUXDO_REDIRECT_URI=https://your-domain.com/api/auth/callback \ No newline at end of file diff --git a/backend/app/api/chapters.py b/backend/app/api/chapters.py index ffd399c..ae81504 100644 --- a/backend/app/api/chapters.py +++ b/backend/app/api/chapters.py @@ -43,6 +43,39 @@ logger = get_logger(__name__) db_write_locks: dict[str, Lock] = {} +async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project: + """ + 验证用户是否有权访问指定项目 + + Args: + project_id: 项目ID + user_id: 用户ID + db: 数据库会话 + + Returns: + Project: 项目对象 + + Raises: + HTTPException: 401 未登录,404 项目不存在或无权访问 + """ + if not user_id: + raise HTTPException(status_code=401, detail="未登录") + + result = await db.execute( + select(Project).where( + Project.id == project_id, + Project.user_id == user_id + ) + ) + project = result.scalar_one_or_none() + + if not project: + logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}") + raise HTTPException(status_code=404, detail="项目不存在或无权访问") + + return project + + async def get_db_write_lock(user_id: str) -> Lock: """获取或创建用户的数据库写入锁""" if user_id not in db_write_locks: @@ -54,16 +87,13 @@ async def get_db_write_lock(user_id: str) -> Lock: @router.post("", response_model=ChapterResponse, summary="创建章节") async def create_chapter( chapter: ChapterCreate, + request: Request, db: AsyncSession = Depends(get_db) ): """创建新的章节""" - # 验证项目是否存在 - result = await db.execute( - select(Project).where(Project.id == chapter.project_id) - ) - project = result.scalar_one_or_none() - if not project: - raise HTTPException(status_code=404, detail="项目不存在") + # 验证用户权限和项目是否存在 + user_id = getattr(request.state, 'user_id', None) + project = await verify_project_access(chapter.project_id, user_id, db) # 计算字数 word_count = len(chapter.content) @@ -85,9 +115,14 @@ async def create_chapter( @router.get("/project/{project_id}", response_model=ChapterListResponse, summary="获取项目的所有章节") async def get_project_chapters( project_id: str, + request: Request, db: AsyncSession = Depends(get_db) ): """获取指定项目的所有章节(路径参数版本)""" + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(project_id, user_id, db) + # 获取总数 count_result = await db.execute( select(func.count(Chapter.id)).where(Chapter.project_id == project_id) @@ -108,6 +143,7 @@ async def get_project_chapters( @router.get("/{chapter_id}", response_model=ChapterResponse, summary="获取章节详情") async def get_chapter( chapter_id: str, + request: Request, db: AsyncSession = Depends(get_db) ): """根据ID获取章节详情""" @@ -119,12 +155,17 @@ async def get_chapter( if not chapter: raise HTTPException(status_code=404, detail="章节不存在") + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(chapter.project_id, user_id, db) + return chapter @router.get("/{chapter_id}/navigation", summary="获取章节导航信息") async def get_chapter_navigation( chapter_id: str, + request: Request, db: AsyncSession = Depends(get_db) ): """ @@ -140,6 +181,10 @@ async def get_chapter_navigation( if not current_chapter: raise HTTPException(status_code=404, detail="章节不存在") + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(current_chapter.project_id, user_id, db) + # 获取上一章 prev_result = await db.execute( select(Chapter) @@ -183,6 +228,7 @@ async def get_chapter_navigation( async def update_chapter( chapter_id: str, chapter_update: ChapterUpdate, + request: Request, db: AsyncSession = Depends(get_db) ): """更新章节信息""" @@ -194,6 +240,10 @@ async def update_chapter( if not chapter: raise HTTPException(status_code=404, detail="章节不存在") + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(chapter.project_id, user_id, db) + # 记录旧字数 old_word_count = chapter.word_count or 0 @@ -223,6 +273,7 @@ async def update_chapter( @router.delete("/{chapter_id}", summary="删除章节") async def delete_chapter( chapter_id: str, + request: Request, db: AsyncSession = Depends(get_db) ): """删除章节""" @@ -234,6 +285,10 @@ async def delete_chapter( if not chapter: raise HTTPException(status_code=404, detail="章节不存在") + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(chapter.project_id, user_id, db) + # 更新项目字数 result = await db.execute( select(Project).where(Project.id == chapter.project_id) @@ -481,6 +536,7 @@ async def build_smart_chapter_context( @router.get("/{chapter_id}/can-generate", summary="检查章节是否可以生成") async def check_can_generate( chapter_id: str, + request: Request, db: AsyncSession = Depends(get_db) ): """ @@ -495,6 +551,10 @@ async def check_can_generate( if not chapter: raise HTTPException(status_code=404, detail="章节不存在") + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(chapter.project_id, user_id, db) + # 检查前置条件 can_generate, error_msg, previous_chapters = await check_prerequisites(db, chapter) @@ -1238,6 +1298,7 @@ async def generate_chapter_content_stream( @router.get("/{chapter_id}/analysis/status", summary="查询章节分析任务状态") async def get_analysis_task_status( chapter_id: str, + request: Request, db: AsyncSession = Depends(get_db) ): """ @@ -1248,16 +1309,32 @@ async def get_analysis_task_status( - 如果任务状态为pending且超过2分钟未启动,自动标记为failed 返回: - - task_id: 任务ID - - status: pending/running/completed/failed + - has_task: 是否存在分析任务 + - task_id: 任务ID(如果存在) + - status: pending/running/completed/failed/none(如果不存在则为none) - progress: 0-100 - error_message: 错误信息(如果失败) - auto_recovered: 是否被自动恢复 - created_at: 创建时间 - completed_at: 完成时间 + + 注意:当章节不存在或无权访问时返回404,当没有分析任务时返回has_task=false """ from datetime import timedelta + # 先获取章节以验证存在性和权限 + chapter_result = await db.execute( + select(Chapter).where(Chapter.id == chapter_id) + ) + chapter = chapter_result.scalar_one_or_none() + + if not chapter: + raise HTTPException(status_code=404, detail="章节不存在") + + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(chapter.project_id, user_id, db) + # 获取该章节最新的分析任务 result = await db.execute( select(AnalysisTask) @@ -1268,7 +1345,19 @@ async def get_analysis_task_status( task = result.scalar_one_or_none() if not task: - raise HTTPException(status_code=404, detail="未找到分析任务") + # 返回无任务状态,而不是抛出404错误 + return { + "has_task": False, + "chapter_id": chapter_id, + "status": "none", + "progress": 0, + "error_message": None, + "auto_recovered": False, + "task_id": None, + "created_at": None, + "started_at": None, + "completed_at": None + } auto_recovered = False current_time = datetime.now() @@ -1299,6 +1388,7 @@ async def get_analysis_task_status( logger.warning(f"🔄 自动恢复未启动的任务: {task.id}, 章节: {chapter_id}") return { + "has_task": True, "task_id": task.id, "chapter_id": task.chapter_id, "status": task.status, @@ -1314,6 +1404,7 @@ async def get_analysis_task_status( @router.get("/{chapter_id}/analysis", summary="获取章节分析结果") async def get_chapter_analysis( chapter_id: str, + request: Request, db: AsyncSession = Depends(get_db) ): """ @@ -1325,6 +1416,16 @@ async def get_chapter_analysis( - memories: 提取的记忆列表 - created_at: 分析时间 """ + # 先获取章节以验证权限 + chapter_result_check = await db.execute( + select(Chapter).where(Chapter.id == chapter_id) + ) + chapter_check = chapter_result_check.scalar_one_or_none() + if chapter_check: + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(chapter_check.project_id, user_id, db) + # 获取分析结果 analysis_result = await db.execute( select(PlotAnalysis) @@ -1369,6 +1470,7 @@ async def get_chapter_analysis( @router.get("/{chapter_id}/annotations", summary="获取章节标注数据") async def get_chapter_annotations( chapter_id: str, + request: Request, db: AsyncSession = Depends(get_db) ): """ @@ -1377,6 +1479,9 @@ async def get_chapter_annotations( 返回格式化的标注列表,包含精确位置信息 适用于章节内容的可视化标注展示 """ + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + # 获取章节 chapter_result = await db.execute( select(Chapter).where(Chapter.id == chapter_id) @@ -1386,6 +1491,9 @@ async def get_chapter_annotations( if not chapter: raise HTTPException(status_code=404, detail="章节不存在") + # 验证项目访问权限 + await verify_project_access(chapter.project_id, user_id, db) + # 获取分析结果 analysis_result = await db.execute( select(PlotAnalysis) @@ -1623,13 +1731,8 @@ async def batch_generate_chapters_in_order( if not user_id: raise HTTPException(status_code=401, detail="未登录") - # 验证项目存在 - project_result = await db.execute( - select(Project).where(Project.id == project_id) - ) - project = project_result.scalar_one_or_none() - if not project: - raise HTTPException(status_code=404, detail="项目不存在") + # 验证项目存在和用户权限 + project = await verify_project_access(project_id, user_id, db) # 获取项目的所有章节,按序号排序 result = await db.execute( @@ -1750,12 +1853,17 @@ async def get_batch_generation_status( @router.get("/project/{project_id}/batch-generate/active", summary="获取项目当前运行中的批量生成任务") async def get_active_batch_generation( project_id: str, + request: Request, db: AsyncSession = Depends(get_db) ): """ 获取项目当前运行中的批量生成任务 用于页面刷新后恢复任务状态 """ + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(project_id, user_id, db) + result = await db.execute( select(BatchGenerationTask) .where(BatchGenerationTask.project_id == project_id) diff --git a/backend/app/api/characters.py b/backend/app/api/characters.py index 85f6446..3880a5b 100644 --- a/backend/app/api/characters.py +++ b/backend/app/api/characters.py @@ -24,12 +24,50 @@ router = APIRouter(prefix="/characters", tags=["角色管理"]) logger = get_logger(__name__) +async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project: + """ + 验证用户是否有权访问指定项目 + + Args: + project_id: 项目ID + user_id: 用户ID + db: 数据库会话 + + Returns: + Project: 项目对象 + + Raises: + HTTPException: 401 未登录,404 项目不存在或无权访问 + """ + if not user_id: + raise HTTPException(status_code=401, detail="未登录") + + result = await db.execute( + select(Project).where( + Project.id == project_id, + Project.user_id == user_id + ) + ) + project = result.scalar_one_or_none() + + if not project: + logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}") + raise HTTPException(status_code=404, detail="项目不存在或无权访问") + + return project + + @router.get("", response_model=CharacterListResponse, summary="获取角色列表") async def get_characters( project_id: str, + request: Request, db: AsyncSession = Depends(get_db) ): """获取指定项目的所有角色(query参数版本)""" + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(project_id, user_id, db) + # 获取总数 count_result = await db.execute( select(func.count(Character.id)).where(Character.project_id == project_id) @@ -93,9 +131,14 @@ async def get_characters( @router.get("/project/{project_id}", response_model=CharacterListResponse, summary="获取项目的所有角色") async def get_project_characters( project_id: str, + request: Request, db: AsyncSession = Depends(get_db) ): """获取指定项目的所有角色(路径参数版本)""" + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(project_id, user_id, db) + # 获取总数 count_result = await db.execute( select(func.count(Character.id)).where(Character.project_id == project_id) @@ -159,6 +202,7 @@ async def get_project_characters( @router.get("/{character_id}", response_model=CharacterResponse, summary="获取角色详情") async def get_character( character_id: str, + request: Request, db: AsyncSession = Depends(get_db) ): """根据ID获取角色详情""" @@ -170,6 +214,10 @@ async def get_character( if not character: raise HTTPException(status_code=404, detail="角色不存在") + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(character.project_id, user_id, db) + return character @@ -177,6 +225,7 @@ async def get_character( async def update_character( character_id: str, character_update: CharacterUpdate, + request: Request, db: AsyncSession = Depends(get_db) ): """更新角色信息""" @@ -188,6 +237,10 @@ async def update_character( if not character: raise HTTPException(status_code=404, detail="角色不存在") + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(character.project_id, user_id, db) + # 更新字段 update_data = character_update.model_dump(exclude_unset=True) for field, value in update_data.items(): @@ -201,6 +254,7 @@ async def update_character( @router.delete("/{character_id}", summary="删除角色") async def delete_character( character_id: str, + request: Request, db: AsyncSession = Depends(get_db) ): """删除角色""" @@ -212,6 +266,10 @@ async def delete_character( if not character: raise HTTPException(status_code=404, detail="角色不存在") + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(character.project_id, user_id, db) + await db.delete(character) await db.commit() @@ -233,13 +291,9 @@ async def generate_character( 生成内容包括:姓名、年龄、性别、性格、外貌、背景故事、人际关系等 """ - # 验证项目是否存在并获取项目信息 - result = await db.execute( - select(Project).where(Project.id == request.project_id) - ) - project = result.scalar_one_or_none() - if not project: - raise HTTPException(status_code=404, detail="项目不存在") + # 验证用户权限和项目是否存在 + user_id = getattr(http_request.state, 'user_id', None) + project = await verify_project_access(request.project_id, user_id, db) try: # 获取已存在的角色列表,用于关系网络 @@ -295,9 +349,6 @@ async def generate_character( user_input=user_input ) - # 获取user_id用于MCP工具调用 - user_id = http_request.state.user_id if hasattr(http_request.state, 'user_id') else 'default_user' - # 调用AI生成角色(支持MCP工具) logger.info(f"🎯 开始为项目 {request.project_id} 生成角色(启用MCP)") logger.info(f" - 角色名:{request.name or 'AI生成'}") diff --git a/backend/app/api/memories.py b/backend/app/api/memories.py index 68e0d16..9e54566 100644 --- a/backend/app/api/memories.py +++ b/backend/app/api/memories.py @@ -6,6 +6,7 @@ from typing import List, Optional from app.database import get_db from app.models.memory import StoryMemory, PlotAnalysis from app.models.chapter import Chapter +from app.models.project import Project from app.services.memory_service import memory_service from app.services.plot_analyzer import get_plot_analyzer from app.services.ai_service import create_user_ai_service @@ -17,6 +18,26 @@ logger = get_logger(__name__) router = APIRouter(prefix="/api/memories", tags=["memories"]) +async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project: + """验证用户是否有权访问指定项目""" + if not user_id: + raise HTTPException(status_code=401, detail="未登录") + + result = await db.execute( + select(Project).where( + Project.id == project_id, + Project.user_id == user_id + ) + ) + project = result.scalar_one_or_none() + + if not project: + logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}") + raise HTTPException(status_code=404, detail="项目不存在或无权访问") + + return project + + @router.post("/projects/{project_id}/analyze-chapter/{chapter_id}") async def analyze_chapter( project_id: str, @@ -30,7 +51,10 @@ async def analyze_chapter( 对指定章节进行剧情分析,提取钩子、伏笔、情节点等,并存入记忆系统 """ try: - user_id = request.state.user_id + user_id = getattr(request.state, 'user_id', None) + + # 验证用户权限 + await verify_project_access(project_id, user_id, db) # 获取章节内容 result = await db.execute( @@ -192,7 +216,10 @@ async def get_project_memories( ): """获取项目的记忆列表""" try: - user_id = request.state.user_id + user_id = getattr(request.state, 'user_id', None) + + # 验证用户权限 + await verify_project_access(project_id, user_id, db) # 构建查询 query = select(StoryMemory).where(StoryMemory.project_id == project_id) @@ -222,10 +249,16 @@ async def get_project_memories( async def get_chapter_analysis( project_id: str, chapter_id: str, + request: Request, db: AsyncSession = Depends(get_db) ): """获取章节的剧情分析""" try: + user_id = getattr(request.state, 'user_id', None) + + # 验证用户权限 + await verify_project_access(project_id, user_id, db) + result = await db.execute( select(PlotAnalysis).where( and_( @@ -258,11 +291,15 @@ async def search_memories( query: str, memory_types: Optional[List[str]] = None, limit: int = 10, - min_importance: float = 0.0 + min_importance: float = 0.0, + db: AsyncSession = Depends(get_db) ): """语义搜索项目记忆""" try: - user_id = request.state.user_id + user_id = getattr(request.state, 'user_id', None) + + # 验证用户权限 + await verify_project_access(project_id, user_id, db) memories = await memory_service.search_memories( user_id=user_id, @@ -294,7 +331,10 @@ async def get_unresolved_foreshadows( ): """获取未完结的伏笔""" try: - user_id = request.state.user_id + user_id = getattr(request.state, 'user_id', None) + + # 验证用户权限 + await verify_project_access(project_id, user_id, db) # 从向量库搜索 foreshadows = await memory_service.find_unresolved_foreshadows( @@ -317,11 +357,15 @@ async def get_unresolved_foreshadows( @router.get("/projects/{project_id}/stats") async def get_memory_stats( project_id: str, - request: Request + request: Request, + db: AsyncSession = Depends(get_db) ): """获取记忆统计信息""" try: - user_id = request.state.user_id + user_id = getattr(request.state, 'user_id', None) + + # 验证用户权限 + await verify_project_access(project_id, user_id, db) stats = await memory_service.get_memory_stats( user_id=user_id, @@ -347,7 +391,10 @@ async def delete_chapter_memories( ): """删除章节的所有记忆""" try: - user_id = request.state.user_id + user_id = getattr(request.state, 'user_id', None) + + # 验证用户权限 + await verify_project_access(project_id, user_id, db) # 从数据库删除 result = await db.execute( diff --git a/backend/app/api/organizations.py b/backend/app/api/organizations.py index f6e01a4..287f663 100644 --- a/backend/app/api/organizations.py +++ b/backend/app/api/organizations.py @@ -1,5 +1,5 @@ """组织管理API""" -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException, Request from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, and_ from typing import List, Optional @@ -31,6 +31,26 @@ router = APIRouter(prefix="/organizations", tags=["组织管理"]) logger = get_logger(__name__) +async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project: + """验证用户是否有权访问指定项目""" + if not user_id: + raise HTTPException(status_code=401, detail="未登录") + + result = await db.execute( + select(Project).where( + Project.id == project_id, + Project.user_id == user_id + ) + ) + project = result.scalar_one_or_none() + + if not project: + logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}") + raise HTTPException(status_code=404, detail="项目不存在或无权访问") + + return project + + class OrganizationGenerateRequest(BaseModel): """AI生成组织的请求模型""" project_id: str = Field(..., description="项目ID") @@ -44,8 +64,13 @@ class OrganizationGenerateRequest(BaseModel): @router.get("/project/{project_id}", response_model=List[OrganizationDetailResponse], summary="获取项目的所有组织") async def get_project_organizations( project_id: str, + request: Request, db: AsyncSession = Depends(get_db) ): + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(project_id, user_id, db) + """ 获取项目中的所有组织及其详情 @@ -85,6 +110,7 @@ async def get_project_organizations( @router.get("/{org_id}", response_model=OrganizationResponse, summary="获取组织详情") async def get_organization( org_id: str, + request: Request, db: AsyncSession = Depends(get_db) ): """获取组织的详细信息""" @@ -96,12 +122,17 @@ async def get_organization( if not org: raise HTTPException(status_code=404, detail="组织不存在") + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(org.project_id, user_id, db) + return org @router.post("/", response_model=OrganizationResponse, summary="创建组织") async def create_organization( organization: OrganizationCreate, + request: Request, db: AsyncSession = Depends(get_db) ): """ @@ -110,6 +141,10 @@ async def create_organization( - 需要关联到一个已存在的角色记录(is_organization=True) - 可以设置父组织、势力等级等属性 """ + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(organization.project_id, user_id, db) + # 验证角色是否存在且是组织 char_result = await db.execute( select(Character).where(Character.id == organization.character_id) @@ -142,6 +177,7 @@ async def create_organization( async def update_organization( org_id: str, organization: OrganizationUpdate, + request: Request, db: AsyncSession = Depends(get_db) ): """更新组织的属性""" @@ -153,6 +189,10 @@ async def update_organization( if not db_org: raise HTTPException(status_code=404, detail="组织不存在") + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(db_org.project_id, user_id, db) + # 更新字段 update_data = organization.model_dump(exclude_unset=True) for field, value in update_data.items(): @@ -168,6 +208,7 @@ async def update_organization( @router.delete("/{org_id}", summary="删除组织") async def delete_organization( org_id: str, + request: Request, db: AsyncSession = Depends(get_db) ): """删除组织(会级联删除所有成员关系)""" @@ -179,6 +220,10 @@ async def delete_organization( if not db_org: raise HTTPException(status_code=404, detail="组织不存在") + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(db_org.project_id, user_id, db) + await db.delete(db_org) await db.commit() @@ -191,6 +236,7 @@ async def delete_organization( @router.get("/{org_id}/members", response_model=List[OrganizationMemberDetailResponse], summary="获取组织成员") async def get_organization_members( org_id: str, + request: Request, db: AsyncSession = Depends(get_db) ): """ @@ -202,9 +248,14 @@ async def get_organization_members( org_result = await db.execute( select(Organization).where(Organization.id == org_id) ) - if not org_result.scalar_one_or_none(): + org = org_result.scalar_one_or_none() + if not org: raise HTTPException(status_code=404, detail="组织不存在") + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(org.project_id, user_id, db) + # 获取成员列表 result = await db.execute( select(OrganizationMember) @@ -244,6 +295,7 @@ async def get_organization_members( async def add_organization_member( org_id: str, member: OrganizationMemberCreate, + request: Request, db: AsyncSession = Depends(get_db) ): """ @@ -260,6 +312,10 @@ async def add_organization_member( if not org: raise HTTPException(status_code=404, detail="组织不存在") + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(org.project_id, user_id, db) + # 验证角色存在 char_result = await db.execute( select(Character).where(Character.id == member.character_id) @@ -304,6 +360,7 @@ async def add_organization_member( async def update_organization_member( member_id: str, member: OrganizationMemberUpdate, + request: Request, db: AsyncSession = Depends(get_db) ): """更新组织成员的职位、忠诚度等信息""" @@ -315,6 +372,14 @@ async def update_organization_member( if not db_member: raise HTTPException(status_code=404, detail="成员记录不存在") + # 通过成员所属的组织验证用户权限 + org_result = await db.execute( + select(Organization).where(Organization.id == db_member.organization_id) + ) + org = org_result.scalar_one() + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(org.project_id, user_id, db) + # 更新字段 update_data = member.model_dump(exclude_unset=True) for field, value in update_data.items(): @@ -330,6 +395,7 @@ async def update_organization_member( @router.delete("/members/{member_id}", summary="移除组织成员") async def remove_organization_member( member_id: str, + request: Request, db: AsyncSession = Depends(get_db) ): """ @@ -350,6 +416,10 @@ async def remove_organization_member( select(Organization).where(Organization.id == db_member.organization_id) ) org = org_result.scalar_one() + + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(org.project_id, user_id, db) org.member_count = max(0, org.member_count - 1) await db.delete(db_member) @@ -360,7 +430,8 @@ async def remove_organization_member( @router.post("/generate", response_model=CharacterResponse, summary="AI生成组织") async def generate_organization( - request: OrganizationGenerateRequest, + gen_request: OrganizationGenerateRequest, + http_request: Request, db: AsyncSession = Depends(get_db), user_ai_service: AIService = Depends(get_user_ai_service) ): @@ -372,19 +443,15 @@ async def generate_organization( 生成内容包括:组织名称、类型、特性、背景、目的、势力等级等 """ - # 验证项目是否存在并获取项目信息 - result = await db.execute( - select(Project).where(Project.id == request.project_id) - ) - project = result.scalar_one_or_none() - if not project: - raise HTTPException(status_code=404, detail="项目不存在") + # 验证用户权限 + user_id = getattr(http_request.state, 'user_id', None) + project = await verify_project_access(gen_request.project_id, user_id, db) try: # 获取已存在的角色和组织列表 existing_chars_result = await db.execute( select(Character) - .where(Character.project_id == request.project_id) + .where(Character.project_id == gen_request.project_id) .order_by(Character.created_at.desc()) ) existing_characters = existing_chars_result.scalars().all() @@ -422,10 +489,10 @@ async def generate_organization( # 构建用户输入信息 user_input = f""" 用户要求: -- 组织名称:{request.name or '请AI生成'} -- 组织类型:{request.organization_type or '请AI根据世界观决定'} -- 背景设定:{request.background or '无特殊要求'} -- 其他要求:{request.requirements or '无'} +- 组织名称:{gen_request.name or '请AI生成'} +- 组织类型:{gen_request.organization_type or '请AI根据世界观决定'} +- 背景设定:{gen_request.background or '无特殊要求'} +- 其他要求:{gen_request.requirements or '无'} """ # 使用统一的提示词服务 @@ -435,10 +502,10 @@ async def generate_organization( ) # 调用AI生成组织 - logger.info(f"🎯 开始为项目 {request.project_id} 生成组织") - logger.info(f" - 组织名:{request.name or 'AI生成'}") - logger.info(f" - 组织类型:{request.organization_type or 'AI决定'}") - logger.info(f" - 背景设定:{request.background or '无'}") + logger.info(f"🎯 开始为项目 {gen_request.project_id} 生成组织") + logger.info(f" - 组织名:{gen_request.name or 'AI生成'}") + logger.info(f" - 组织类型:{gen_request.organization_type or 'AI决定'}") + logger.info(f" - 背景设定:{gen_request.background or '无'}") logger.info(f" - AI提供商:{user_ai_service.api_provider}") logger.info(f" - AI模型:{user_ai_service.default_model}") logger.info(f" - Prompt长度:{len(prompt)} 字符") @@ -492,8 +559,8 @@ async def generate_organization( # 创建角色记录(组织也是角色的一种) character = Character( - project_id=request.project_id, - name=organization_data.get("name", request.name or "未命名组织"), + project_id=gen_request.project_id, + name=organization_data.get("name", gen_request.name or "未命名组织"), is_organization=True, role_type="supporting", # 组织通常作为配角 personality=organization_data.get("personality", ""), @@ -518,7 +585,7 @@ async def generate_organization( # 自动创建Organization详情记录 organization = Organization( character_id=character.id, - project_id=request.project_id, + project_id=gen_request.project_id, member_count=0, power_level=organization_data.get("power_level", 50), location=organization_data.get("location"), @@ -532,7 +599,7 @@ async def generate_organization( # 记录生成历史 history = GenerationHistory( - project_id=request.project_id, + project_id=gen_request.project_id, prompt=prompt, generated_content=ai_content, model=user_ai_service.default_model @@ -542,7 +609,7 @@ async def generate_organization( await db.commit() await db.refresh(character) - logger.info(f"🎉 成功为项目 {request.project_id} 生成组织: {character.name}") + logger.info(f"🎉 成功为项目 {gen_request.project_id} 生成组织: {character.name}") return character diff --git a/backend/app/api/outlines.py b/backend/app/api/outlines.py index a7dcced..31897b7 100644 --- a/backend/app/api/outlines.py +++ b/backend/app/api/outlines.py @@ -30,19 +30,49 @@ router = APIRouter(prefix="/outlines", tags=["大纲管理"]) logger = get_logger(__name__) +async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project: + """ + 验证用户是否有权访问指定项目 + + Args: + project_id: 项目ID + user_id: 用户ID + db: 数据库会话 + + Returns: + Project: 项目对象 + + Raises: + HTTPException: 401 未登录,404 项目不存在或无权访问 + """ + if not user_id: + raise HTTPException(status_code=401, detail="未登录") + + result = await db.execute( + select(Project).where( + Project.id == project_id, + Project.user_id == user_id + ) + ) + project = result.scalar_one_or_none() + + if not project: + logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}") + raise HTTPException(status_code=404, detail="项目不存在或无权访问") + + return project + + @router.post("", response_model=OutlineResponse, summary="创建大纲") async def create_outline( outline: OutlineCreate, + request: Request, db: AsyncSession = Depends(get_db) ): """创建新的章节大纲,同时创建对应的章节记录""" - # 验证项目是否存在 - result = await db.execute( - select(Project).where(Project.id == outline.project_id) - ) - project = result.scalar_one_or_none() - if not project: - raise HTTPException(status_code=404, detail="项目不存在") + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(outline.project_id, user_id, db) # 创建大纲 db_outline = Outline(**outline.model_dump()) @@ -66,9 +96,14 @@ async def create_outline( @router.get("", response_model=OutlineListResponse, summary="获取大纲列表") async def get_outlines( project_id: str, + request: Request, db: AsyncSession = Depends(get_db) ): """获取指定项目的所有大纲""" + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(project_id, user_id, db) + # 获取总数 count_result = await db.execute( select(func.count(Outline.id)).where(Outline.project_id == project_id) @@ -89,9 +124,14 @@ async def get_outlines( @router.get("/project/{project_id}", response_model=OutlineListResponse, summary="获取项目的所有大纲") async def get_project_outlines( project_id: str, + request: Request, db: AsyncSession = Depends(get_db) ): """获取指定项目的所有大纲(路径参数版本)""" + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(project_id, user_id, db) + # 获取总数 count_result = await db.execute( select(func.count(Outline.id)).where(Outline.project_id == project_id) @@ -112,6 +152,7 @@ async def get_project_outlines( @router.get("/{outline_id}", response_model=OutlineResponse, summary="获取大纲详情") async def get_outline( outline_id: str, + request: Request, db: AsyncSession = Depends(get_db) ): """根据ID获取大纲详情""" @@ -123,6 +164,10 @@ async def get_outline( if not outline: raise HTTPException(status_code=404, detail="大纲不存在") + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(outline.project_id, user_id, db) + return outline @@ -130,6 +175,7 @@ async def get_outline( async def update_outline( outline_id: str, outline_update: OutlineUpdate, + request: Request, db: AsyncSession = Depends(get_db) ): """更新大纲信息,同步更新对应章节和structure字段""" @@ -141,6 +187,10 @@ async def update_outline( if not outline: raise HTTPException(status_code=404, detail="大纲不存在") + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(outline.project_id, user_id, db) + # 更新字段 update_data = outline_update.model_dump(exclude_unset=True) for field, value in update_data.items(): @@ -196,6 +246,7 @@ async def update_outline( @router.delete("/{outline_id}", summary="删除大纲") async def delete_outline( outline_id: str, + request: Request, db: AsyncSession = Depends(get_db) ): """删除大纲,同步删除章节,并重新排序后续项""" @@ -207,6 +258,10 @@ async def delete_outline( if not outline: raise HTTPException(status_code=404, detail="大纲不存在") + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(outline.project_id, user_id, db) + project_id = outline.project_id deleted_order = outline.order_index @@ -252,7 +307,8 @@ async def delete_outline( @router.post("/reorder", summary="批量重排序大纲") async def reorder_outlines( - request: OutlineReorderRequest, + reorder_request: OutlineReorderRequest, + http_request: Request, db: AsyncSession = Depends(get_db) ): """ @@ -261,10 +317,20 @@ async def reorder_outlines( 策略:先收集所有变更,最后一次性提交,避免临时冲突 """ try: + # 验证用户权限(通过第一个大纲的project_id) + user_id = getattr(http_request.state, 'user_id', None) + if reorder_request.orders and len(reorder_request.orders) > 0: + first_outline_result = await db.execute( + select(Outline).where(Outline.id == reorder_request.orders[0].id) + ) + first_outline = first_outline_result.scalar_one_or_none() + if first_outline: + await verify_project_access(first_outline.project_id, user_id, db) + # 第一步:收集所有大纲和对应的章节 outline_chapter_map = {} # {outline_id: (outline, chapter, old_order, new_order)} - for item in request.orders: + for item in reorder_request.orders: outline_id = item.id new_order = item.order_index @@ -341,13 +407,9 @@ async def generate_outline( - new: 强制全新生成 - continue: 强制续写模式 """ - # 验证项目是否存在 - result = await db.execute( - select(Project).where(Project.id == request.project_id) - ) - project = result.scalar_one_or_none() - if not project: - raise HTTPException(status_code=404, detail="项目不存在") + # 验证用户权限 + user_id = getattr(http_request.state, 'user_id', None) + project = await verify_project_access(request.project_id, user_id, db) try: # 获取现有大纲(强制从数据库获取最新数据,包括用户手动修改的内容) @@ -1472,13 +1534,9 @@ async def generate_outline_stream( "model": "gpt-4" // 可选 } """ - # 验证项目是否存在 - result = await db.execute( - select(Project).where(Project.id == data.get("project_id")) - ) - project = result.scalar_one_or_none() - if not project: - raise HTTPException(status_code=404, detail="项目不存在") + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + project = await verify_project_access(data.get("project_id"), user_id, db) # 判断模式 mode = data.get("mode", "auto") diff --git a/backend/app/api/projects.py b/backend/app/api/projects.py index 51f6269..42e820f 100644 --- a/backend/app/api/projects.py +++ b/backend/app/api/projects.py @@ -41,17 +41,31 @@ router = APIRouter(prefix="/projects", tags=["项目管理"]) @router.post("", response_model=ProjectResponse, summary="创建项目") async def create_project( project: ProjectCreate, - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), + request: Request = None ): try: - logger.info(f"创建新项目: {project.title}") - db_project = Project(**project.model_dump()) + # 从认证中间件获取用户ID + user_id = getattr(request.state, 'user_id', None) + if not user_id: + logger.warning("未登录用户尝试创建项目") + raise HTTPException(status_code=401, detail="未登录") + + logger.info(f"创建新项目: {project.title}, user_id={user_id}") + + # 创建项目时自动设置user_id + project_data = project.model_dump() + project_data['user_id'] = user_id + db_project = Project(**project_data) + db.add(db_project) await db.commit() await db.refresh(db_project) - logger.info(f"项目创建成功: {db_project.id}") + logger.info(f"项目创建成功: project_id={db_project.id}, user_id={user_id}") return db_project + except HTTPException: + raise except Exception as e: logger.error(f"创建项目失败: {str(e)}", exc_info=True) raise @@ -61,24 +75,38 @@ async def create_project( async def get_projects( skip: int = 0, limit: int = 100, - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), + request: Request = None ): - """获取所有项目列表""" + """获取当前用户的项目列表""" try: - logger.debug(f"获取项目列表: skip={skip}, limit={limit}") - count_result = await db.execute(select(func.count(Project.id))) + # 从认证中间件获取用户ID + user_id = getattr(request.state, 'user_id', None) + if not user_id: + logger.warning("未登录用户尝试获取项目列表") + raise HTTPException(status_code=401, detail="未登录") + + logger.debug(f"获取项目列表: user_id={user_id}, skip={skip}, limit={limit}") + + # 只查询当前用户的项目 + count_result = await db.execute( + select(func.count(Project.id)).where(Project.user_id == user_id) + ) total = count_result.scalar_one() result = await db.execute( select(Project) + .where(Project.user_id == user_id) .order_by(Project.updated_at.desc()) .offset(skip) .limit(limit) ) projects = result.scalars().all() - logger.info(f"获取项目列表成功: 共{total}个项目") + logger.info(f"获取项目列表成功: user_id={user_id}, 共{total}个项目") return ProjectListResponse(total=total, items=projects) + except HTTPException: + raise except Exception as e: logger.error(f"获取项目列表失败: {str(e)}", exc_info=True) raise @@ -87,17 +115,29 @@ async def get_projects( @router.get("/{project_id}", response_model=ProjectResponse, summary="获取项目详情") async def get_project( project_id: str, - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), + request: Request = None ): try: - logger.debug(f"获取项目详情: {project_id}") + # 从认证中间件获取用户ID + user_id = getattr(request.state, 'user_id', None) + if not user_id: + logger.warning("未登录用户尝试获取项目详情") + raise HTTPException(status_code=401, detail="未登录") + + logger.debug(f"获取项目详情: project_id={project_id}, user_id={user_id}") + + # 只查询当前用户的项目 result = await db.execute( - select(Project).where(Project.id == project_id) + select(Project).where( + Project.id == project_id, + Project.user_id == user_id + ) ) project = result.scalar_one_or_none() if not project: - logger.warning(f"项目不存在: {project_id}") + logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}") raise HTTPException(status_code=404, detail="项目不存在") logger.info(f"获取项目详情成功: {project.title}") @@ -113,17 +153,29 @@ async def get_project( async def update_project( project_id: str, project_update: ProjectUpdate, - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), + request: Request = None ): try: - logger.info(f"更新项目: {project_id}") + # 从认证中间件获取用户ID + user_id = getattr(request.state, 'user_id', None) + if not user_id: + logger.warning("未登录用户尝试更新项目") + raise HTTPException(status_code=401, detail="未登录") + + logger.info(f"更新项目: project_id={project_id}, user_id={user_id}") + + # 只查询当前用户的项目 result = await db.execute( - select(Project).where(Project.id == project_id) + select(Project).where( + Project.id == project_id, + Project.user_id == user_id + ) ) project = result.scalar_one_or_none() if not project: - logger.warning(f"项目不存在: {project_id}") + logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}") raise HTTPException(status_code=404, detail="项目不存在") update_data = project_update.model_dump(exclude_unset=True) @@ -149,22 +201,30 @@ async def delete_project( db: AsyncSession = Depends(get_db) ): try: - logger.info(f"删除项目: {project_id}") + # 从认证中间件获取用户ID + user_id = getattr(request.state, 'user_id', None) + if not user_id: + logger.warning("未登录用户尝试删除项目") + raise HTTPException(status_code=401, detail="未登录") + + logger.info(f"删除项目: project_id={project_id}, user_id={user_id}") + + # 只查询当前用户的项目 result = await db.execute( - select(Project).where(Project.id == project_id) + select(Project).where( + Project.id == project_id, + Project.user_id == user_id + ) ) project = result.scalar_one_or_none() if not project: - logger.warning(f"项目不存在: {project_id}") + logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}") raise HTTPException(status_code=404, detail="项目不存在") project_title = project.title - # 从认证中间件获取用户ID - user_id = getattr(request.state, 'user_id', None) - - # 删除向量数据库中的记忆 + # 删除向量数据库中的记忆(user_id已在上面获取) if user_id: try: await memory_service.delete_project_memories(user_id, project_id) @@ -234,22 +294,33 @@ async def delete_project( @router.get("/{project_id}/export", summary="导出项目章节为TXT") async def export_project_chapters( project_id: str, - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), + request: Request = None ): """ 导出项目的所有章节内容为TXT文本文件 按章节顺序组织,包含项目基本信息 """ try: - logger.info(f"开始导出项目: {project_id}") + # 从认证中间件获取用户ID + user_id = getattr(request.state, 'user_id', None) + if not user_id: + logger.warning("未登录用户尝试导出项目") + raise HTTPException(status_code=401, detail="未登录") + logger.info(f"开始导出项目: project_id={project_id}, user_id={user_id}") + + # 只查询当前用户的项目 result = await db.execute( - select(Project).where(Project.id == project_id) + select(Project).where( + Project.id == project_id, + Project.user_id == user_id + ) ) project = result.scalar_one_or_none() if not project: - logger.warning(f"项目不存在: {project_id}") + logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}") raise HTTPException(status_code=404, detail="项目不存在") chapters_result = await db.execute( @@ -326,6 +397,7 @@ async def export_project_chapters( @router.post("/{project_id}/check-consistency", summary="检查数据一致性") async def check_project_consistency( project_id: str, + request: Request, auto_fix: bool = True, db: AsyncSession = Depends(get_db) ): @@ -343,15 +415,25 @@ async def check_project_consistency( - organization_members: 验证组织成员数据完整性 """ try: - logger.info(f"开始数据一致性检查: {project_id}, auto_fix={auto_fix}") + # 从认证中间件获取用户ID + user_id = getattr(request.state, 'user_id', None) + if not user_id: + logger.warning("未登录用户尝试检查数据一致性") + raise HTTPException(status_code=401, detail="未登录") + logger.info(f"开始数据一致性检查: project_id={project_id}, user_id={user_id}, auto_fix={auto_fix}") + + # 只查询当前用户的项目 result = await db.execute( - select(Project).where(Project.id == project_id) + select(Project).where( + Project.id == project_id, + Project.user_id == user_id + ) ) project = result.scalar_one_or_none() if not project: - logger.warning(f"项目不存在: {project_id}") + logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}") raise HTTPException(status_code=404, detail="项目不存在") report = await run_full_data_consistency_check(project_id, db, auto_fix) @@ -369,6 +451,7 @@ async def check_project_consistency( @router.post("/{project_id}/fix-organizations", summary="修复组织记录") async def fix_project_organizations( project_id: str, + request: Request, db: AsyncSession = Depends(get_db) ): """ @@ -377,15 +460,25 @@ async def fix_project_organizations( 为所有is_organization=True但没有Organization记录的Character创建记录 """ try: - logger.info(f"开始修复组织记录: {project_id}") + # 从认证中间件获取用户ID + user_id = getattr(request.state, 'user_id', None) + if not user_id: + logger.warning("未登录用户尝试修复组织记录") + raise HTTPException(status_code=401, detail="未登录") + logger.info(f"开始修复组织记录: project_id={project_id}, user_id={user_id}") + + # 只查询当前用户的项目 result = await db.execute( - select(Project).where(Project.id == project_id) + select(Project).where( + Project.id == project_id, + Project.user_id == user_id + ) ) project = result.scalar_one_or_none() if not project: - logger.warning(f"项目不存在: {project_id}") + logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}") raise HTTPException(status_code=404, detail="项目不存在") fixed_count, total_count = await fix_missing_organization_records(project_id, db) @@ -407,6 +500,7 @@ async def fix_project_organizations( @router.post("/{project_id}/fix-member-counts", summary="修复成员计数") async def fix_project_member_counts( project_id: str, + request: Request, db: AsyncSession = Depends(get_db) ): """ @@ -415,15 +509,25 @@ async def fix_project_member_counts( 从实际成员记录重新计算每个组织的member_count """ try: - logger.info(f"开始修复成员计数: {project_id}") + # 从认证中间件获取用户ID + user_id = getattr(request.state, 'user_id', None) + if not user_id: + logger.warning("未登录用户尝试修复成员计数") + raise HTTPException(status_code=401, detail="未登录") + logger.info(f"开始修复成员计数: project_id={project_id}, user_id={user_id}") + + # 只查询当前用户的项目 result = await db.execute( - select(Project).where(Project.id == project_id) + select(Project).where( + Project.id == project_id, + Project.user_id == user_id + ) ) project = result.scalar_one_or_none() if not project: - logger.warning(f"项目不存在: {project_id}") + logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}") raise HTTPException(status_code=404, detail="项目不存在") fixed_count, total_count = await fix_organization_member_counts(project_id, db) @@ -445,6 +549,7 @@ async def fix_project_member_counts( @router.post("/{project_id}/export-data", summary="导出项目数据为JSON") async def export_project_data( project_id: str, + request: Request, options: ExportOptions, db: AsyncSession = Depends(get_db) ): @@ -459,16 +564,25 @@ async def export_project_data( JSON文件下载 """ try: - logger.info(f"开始导出项目数据: {project_id}") + # 从认证中间件获取用户ID + user_id = getattr(request.state, 'user_id', None) + if not user_id: + logger.warning("未登录用户尝试导出项目数据") + raise HTTPException(status_code=401, detail="未登录") - # 检查项目是否存在 + logger.info(f"开始导出项目数据: project_id={project_id}, user_id={user_id}") + + # 只查询当前用户的项目 result = await db.execute( - select(Project).where(Project.id == project_id) + select(Project).where( + Project.id == project_id, + Project.user_id == user_id + ) ) project = result.scalar_one_or_none() if not project: - logger.warning(f"项目不存在: {project_id}") + logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}") raise HTTPException(status_code=404, detail="项目不存在") # 导出数据 diff --git a/backend/app/api/relationships.py b/backend/app/api/relationships.py index 1ea4fcd..ee8ed32 100644 --- a/backend/app/api/relationships.py +++ b/backend/app/api/relationships.py @@ -1,5 +1,5 @@ """关系管理API""" -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, HTTPException, Query, Request from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, or_, and_ from typing import List, Optional @@ -12,6 +12,7 @@ from app.models.relationship import ( OrganizationMember ) from app.models.character import Character +from app.models.project import Project from app.schemas.relationship import ( RelationshipTypeResponse, CharacterRelationshipCreate, @@ -27,6 +28,26 @@ router = APIRouter(prefix="/relationships", tags=["关系管理"]) logger = get_logger(__name__) +async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project: + """验证用户是否有权访问指定项目""" + if not user_id: + raise HTTPException(status_code=401, detail="未登录") + + result = await db.execute( + select(Project).where( + Project.id == project_id, + Project.user_id == user_id + ) + ) + project = result.scalar_one_or_none() + + if not project: + logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}") + raise HTTPException(status_code=404, detail="项目不存在或无权访问") + + return project + + @router.get("/types", response_model=List[RelationshipTypeResponse], summary="获取关系类型列表") async def get_relationship_types(db: AsyncSession = Depends(get_db)): """获取所有预定义的关系类型""" @@ -38,9 +59,14 @@ async def get_relationship_types(db: AsyncSession = Depends(get_db)): @router.get("/project/{project_id}", response_model=List[CharacterRelationshipResponse], summary="获取项目的所有关系") async def get_project_relationships( project_id: str, + request: Request, character_id: Optional[str] = Query(None, description="筛选特定角色的关系"), db: AsyncSession = Depends(get_db) ): + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(project_id, user_id, db) + """ 获取项目中的所有角色关系 @@ -70,8 +96,13 @@ async def get_project_relationships( @router.get("/graph/{project_id}", response_model=RelationshipGraphData, summary="获取关系图谱数据") async def get_relationship_graph( project_id: str, + request: Request, db: AsyncSession = Depends(get_db) ): + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(project_id, user_id, db) + """ 获取用于可视化的关系图谱数据 @@ -122,6 +153,7 @@ async def get_relationship_graph( @router.post("/", response_model=CharacterRelationshipResponse, summary="创建角色关系") async def create_relationship( relationship: CharacterRelationshipCreate, + request: Request, db: AsyncSession = Depends(get_db) ): """ @@ -131,6 +163,10 @@ async def create_relationship( - 可以指定预定义的关系类型或自定义关系名称 - 可以设置亲密度、状态等属性 """ + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(relationship.project_id, user_id, db) + # 验证角色是否存在 char_from = await db.execute( select(Character).where(Character.id == relationship.character_from_id) @@ -161,6 +197,7 @@ async def create_relationship( async def update_relationship( relationship_id: str, relationship: CharacterRelationshipUpdate, + request: Request, db: AsyncSession = Depends(get_db) ): """更新角色关系的属性(亲密度、状态等)""" @@ -174,6 +211,10 @@ async def update_relationship( if not db_rel: raise HTTPException(status_code=404, detail="关系不存在") + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(db_rel.project_id, user_id, db) + # 更新字段 update_data = relationship.model_dump(exclude_unset=True) for field, value in update_data.items(): @@ -189,6 +230,7 @@ async def update_relationship( @router.delete("/{relationship_id}", summary="删除关系") async def delete_relationship( relationship_id: str, + request: Request, db: AsyncSession = Depends(get_db) ): """删除角色关系""" @@ -202,6 +244,10 @@ async def delete_relationship( if not db_rel: raise HTTPException(status_code=404, detail="关系不存在") + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(db_rel.project_id, user_id, db) + await db.delete(db_rel) await db.commit() diff --git a/backend/app/api/wizard_stream.py b/backend/app/api/wizard_stream.py index 33826c2..8788495 100644 --- a/backend/app/api/wizard_stream.py +++ b/backend/app/api/wizard_stream.py @@ -183,7 +183,13 @@ async def world_building_generator( # 保存到数据库 yield await SSEResponse.send_progress("保存到数据库...", 90) + # 确保user_id存在 + if not user_id: + yield await SSEResponse.send_error("用户ID缺失,无法创建项目", 401) + return + project = Project( + user_id=user_id, # 添加user_id字段 title=title, description=description, theme=theme, diff --git a/backend/app/api/writing_styles.py b/backend/app/api/writing_styles.py index af503bd..bbe2e9e 100644 --- a/backend/app/api/writing_styles.py +++ b/backend/app/api/writing_styles.py @@ -1,5 +1,5 @@ """写作风格管理 API""" -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException, Request from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, func, delete from typing import List @@ -16,8 +16,30 @@ from ..schemas.writing_style import ( SetDefaultStyleRequest ) from ..services.prompt_service import WritingStyleManager +from ..logger import get_logger router = APIRouter(prefix="/writing-styles", tags=["writing-styles"]) +logger = get_logger(__name__) + + +async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project: + """验证用户是否有权访问指定项目""" + if not user_id: + raise HTTPException(status_code=401, detail="未登录") + + result = await db.execute( + select(Project).where( + Project.id == project_id, + Project.user_id == user_id + ) + ) + project = result.scalar_one_or_none() + + if not project: + logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}") + raise HTTPException(status_code=404, detail="项目不存在或无权访问") + + return project @router.get("/presets/list", response_model=List[dict]) @@ -42,6 +64,7 @@ async def get_preset_styles(): @router.post("", response_model=WritingStyleResponse, status_code=201) async def create_writing_style( style_data: WritingStyleCreate, + request: Request, db: AsyncSession = Depends(get_db) ): """ @@ -50,13 +73,9 @@ async def create_writing_style( - **基于预设创建**:提供 preset_id,系统会自动填充预设内容 - **完全自定义**:不提供 preset_id,需要手动填写所有字段 """ - # 验证项目是否存在 - result = await db.execute( - select(Project).where(Project.id == style_data.project_id) - ) - project = result.scalar_one_or_none() - if not project: - raise HTTPException(status_code=404, detail="项目不存在") + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(style_data.project_id, user_id, db) # 如果基于预设创建,获取预设内容 if style_data.preset_id: @@ -120,6 +139,7 @@ async def create_writing_style( @router.get("/project/{project_id}", response_model=WritingStyleListResponse) async def get_project_styles( project_id: str, + request: Request, db: AsyncSession = Depends(get_db) ): """ @@ -128,13 +148,9 @@ async def get_project_styles( 返回:全局预设风格 + 该项目的自定义风格 按 order_index 排序,并标记哪个是当前项目的默认风格 """ - # 验证项目是否存在 - result = await db.execute( - select(Project).where(Project.id == project_id) - ) - project = result.scalar_one_or_none() - if not project: - raise HTTPException(status_code=404, detail="项目不存在") + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(project_id, user_id, db) # 获取该项目的默认风格ID result = await db.execute( @@ -222,6 +238,7 @@ async def get_writing_style( async def update_writing_style( style_id: int, style_data: WritingStyleUpdate, + request: Request, db: AsyncSession = Depends(get_db) ): """ @@ -241,6 +258,10 @@ async def update_writing_style( if style.project_id is None: raise HTTPException(status_code=403, detail="不能修改全局预设风格,只能修改自定义风格") + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(style.project_id, user_id, db) + # 更新字段 update_data = style_data.model_dump(exclude_unset=True) @@ -279,6 +300,7 @@ async def update_writing_style( @router.delete("/{style_id}", status_code=204) async def delete_writing_style( style_id: int, + request: Request, db: AsyncSession = Depends(get_db) ): """ @@ -300,6 +322,10 @@ async def delete_writing_style( if style.project_id is None: raise HTTPException(status_code=403, detail="不能删除全局预设风格,只能删除自定义风格") + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(style.project_id, user_id, db) + # 检查是否有项目将其设置为默认风格 result = await db.execute( select(ProjectDefaultStyle).where(ProjectDefaultStyle.style_id == style_id) @@ -321,6 +347,7 @@ async def delete_writing_style( async def set_default_style( style_id: int, request_data: SetDefaultStyleRequest, + request: Request, db: AsyncSession = Depends(get_db) ): """ @@ -335,13 +362,9 @@ async def set_default_style( """ project_id = request_data.project_id - # 验证项目是否存在 - result = await db.execute( - select(Project).where(Project.id == project_id) - ) - project = result.scalar_one_or_none() - if not project: - raise HTTPException(status_code=404, detail="项目不存在") + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(project_id, user_id, db) # 验证风格是否存在 result = await db.execute( @@ -379,6 +402,7 @@ async def set_default_style( @router.post("/project/{project_id}/init-defaults", response_model=WritingStyleListResponse) async def initialize_default_styles( project_id: str, + request: Request, db: AsyncSession = Depends(get_db) ): """ @@ -387,13 +411,9 @@ async def initialize_default_styles( 新架构下,预设风格是全局的,不需要为每个项目单独初始化 该接口保留用于兼容性,直接返回项目可用的所有风格 """ - # 验证项目是否存在 - result = await db.execute( - select(Project).where(Project.id == project_id) - ) - project = result.scalar_one_or_none() - if not project: - raise HTTPException(status_code=404, detail="项目不存在") + # 验证用户权限 + user_id = getattr(request.state, 'user_id', None) + await verify_project_access(project_id, user_id, db) # 直接返回项目可用的所有风格(全局预设 + 项目自定义) - return await get_project_styles(project_id, db) \ No newline at end of file + return await get_project_styles(project_id, request, db) \ No newline at end of file diff --git a/backend/app/config.py b/backend/app/config.py index 2247417..f3c57f4 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -3,6 +3,7 @@ from pydantic_settings import BaseSettings from typing import Optional from pathlib import Path import logging +import os # 获取项目根目录(从backend/app/config.py向上两级) PROJECT_ROOT = Path(__file__).parent.parent @@ -12,13 +13,15 @@ DATA_DIR.mkdir(exist_ok=True) # 配置模块使用标准logging(在logger.py初始化之前) config_logger = logging.getLogger(__name__) -# 数据库文件路径(绝对路径) +# 数据库配置:支持PostgreSQL和SQLite +# 优先使用环境变量DATABASE_URL,否则使用SQLite DB_FILE = DATA_DIR / "ai_story.db" +DEFAULT_SQLITE_URL = f"sqlite+aiosqlite:///{str(DB_FILE.absolute()).replace(chr(92), '/')}" -# 生成数据库URL(在类外部生成,确保使用绝对路径) -# 将Windows反斜杠转换为正斜杠,SQLite URL格式要求 -DATABASE_URL = f"sqlite+aiosqlite:///{str(DB_FILE.absolute()).replace(chr(92), '/')}" -config_logger.debug(f"数据库文件路径: {DB_FILE}") +# 从环境变量获取数据库URL,如果未设置则使用SQLite +DATABASE_URL = os.getenv("DATABASE_URL", DEFAULT_SQLITE_URL) + +config_logger.debug(f"数据库类型: {'PostgreSQL' if 'postgresql' in DATABASE_URL else 'SQLite'}") config_logger.debug(f"数据库URL: {DATABASE_URL}") class Settings(BaseSettings): @@ -41,9 +44,31 @@ class Settings(BaseSettings): # CORS配置 cors_origins: list[str] = ["http://localhost:8000", "http://127.0.0.1:8000"] - # 数据库配置 - 使用预先计算好的绝对路径URL + # 数据库配置 - 支持PostgreSQL和SQLite database_url: str = DATABASE_URL + # PostgreSQL连接池配置(优化后支持80-150并发用户) + database_pool_size: int = 30 # 核心连接池大小(从20提升到30) + database_max_overflow: int = 20 # 最大溢出连接数(从10提升到20) + database_pool_timeout: int = 60 # 连接池超时秒数(从30提升到60) + database_pool_recycle: int = 1800 # 连接回收时间秒数(从3600降低到1800,30分钟) + database_pool_pre_ping: bool = True # 连接前ping检测,确保连接有效 + database_pool_use_lifo: bool = True # 使用LIFO策略提高连接复用率 + + # 会话监控配置 + database_session_max_active: int = 50 # 活跃会话警告阈值(从100降低到50) + database_session_leak_threshold: int = 100 # 会话泄漏严重告警阈值 + + # SQLite优化配置 + sqlite_cache_size_mb: int = 128 # SQLite缓存大小MB(从64提升到128) + sqlite_mmap_size_mb: int = 256 # 内存映射I/O大小MB + sqlite_wal_autocheckpoint: int = 1000 # WAL自动检查点间隔 + + # 数据库监控配置 + database_enable_slow_query_log: bool = True # 启用慢查询日志 + database_slow_query_threshold: float = 1.0 # 慢查询阈值(秒) + database_enable_metrics: bool = True # 启用性能指标收集 + # AI服务配置 openai_api_key: Optional[str] = None openai_base_url: Optional[str] = None diff --git a/backend/app/database.py b/backend/app/database.py index 86aebac..1c83002 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -45,12 +45,59 @@ _session_stats = { async def get_engine(user_id: str): """获取或创建用户专属的数据库引擎(线程安全) + 支持PostgreSQL和SQLite两种数据库: + - PostgreSQL: 所有用户共享一个数据库,通过user_id字段隔离数据 + - SQLite: 每个用户一个独立的数据库文件 + Args: user_id: 用户ID Returns: 用户专属的异步引擎 """ + # PostgreSQL模式:所有用户共享同一个引擎 + if "postgresql" in settings.database_url: + cache_key = "shared_postgres" + if cache_key in _engine_cache: + return _engine_cache[cache_key] + + async with _cache_lock: + if cache_key not in _engine_cache: + # 优化后的PostgreSQL连接配置 + connect_args = { + "server_settings": { + "application_name": settings.app_name, + "jit": "off", # 关闭JIT以提高短查询性能 + }, + "command_timeout": 60, # 命令超时60秒 + "statement_cache_size": 500, # 启用语句缓存,提升重复查询性能 + } + + engine = create_async_engine( + settings.database_url, + echo=False, # 生产环境关闭SQL日志 + future=True, + pool_size=settings.database_pool_size, # 核心连接数:30 + max_overflow=settings.database_max_overflow, # 溢出连接数:20 + pool_timeout=settings.database_pool_timeout, # 连接超时:60秒 + pool_pre_ping=settings.database_pool_pre_ping, # 连接前检测 + pool_recycle=settings.database_pool_recycle, # 连接回收:1800秒 + pool_use_lifo=settings.database_pool_use_lifo, # LIFO策略提高复用 + connect_args=connect_args + ) + _engine_cache[cache_key] = engine + logger.info( + f"✅ PostgreSQL引擎已创建(优化配置)\n" + f" ├─ 连接池: {settings.database_pool_size} 核心 + {settings.database_max_overflow} 溢出 = {settings.database_pool_size + settings.database_max_overflow} 总连接\n" + f" ├─ 超时: {settings.database_pool_timeout}秒\n" + f" ├─ 回收: {settings.database_pool_recycle}秒\n" + f" ├─ 策略: LIFO(提高复用率)\n" + f" └─ 预估并发: 80-150用户" + ) + + return _engine_cache[cache_key] + + # SQLite模式:每个用户独立的数据库文件 if user_id in _engine_cache: return _engine_cache[user_id] @@ -76,18 +123,30 @@ async def get_engine(user_id: str): ) try: + # 应用优化后的SQLite配置 + cache_size = -1024 * settings.sqlite_cache_size_mb # 负数表示KB单位 + mmap_size = settings.sqlite_mmap_size_mb * 1024 * 1024 # 转换为字节 + async with engine.begin() as conn: await conn.execute(text("PRAGMA journal_mode=WAL")) await conn.execute(text("PRAGMA synchronous=NORMAL")) - await conn.execute(text("PRAGMA cache_size=-64000")) + await conn.execute(text(f"PRAGMA cache_size={cache_size}")) # 128MB缓存 + await conn.execute(text(f"PRAGMA mmap_size={mmap_size}")) # 256MB内存映射 await conn.execute(text("PRAGMA temp_store=MEMORY")) await conn.execute(text("PRAGMA busy_timeout=5000")) + await conn.execute(text(f"PRAGMA wal_autocheckpoint={settings.sqlite_wal_autocheckpoint}")) - logger.info(f"✅ 用户 {user_id} 的数据库已优化(WAL模式 + 64MB缓存)") + logger.info( + f"✅ 用户 {user_id} 的SQLite数据库已优化\n" + f" ├─ WAL模式\n" + f" ├─ 缓存: {settings.sqlite_cache_size_mb}MB\n" + f" ├─ 内存映射: {settings.sqlite_mmap_size_mb}MB\n" + f" └─ 预估并发: 15-20写入用户" + ) except Exception as e: - logger.warning(f"⚠️ 用户 {user_id} 数据库优化失败: {str(e)}") + logger.warning(f"⚠️ 用户 {user_id} SQLite数据库优化失败: {str(e)}") _engine_cache[user_id] = engine - logger.info(f"为用户 {user_id} 创建数据库引擎") + logger.info(f"为用户 {user_id} 创建SQLite数据库引擎") return _engine_cache[user_id] @@ -157,8 +216,11 @@ async def get_db(request: Request): logger.debug(f"📊 会话关闭 [User:{user_id}][ID:{session_id}] - 活跃:{_session_stats['active']}, 总创建:{_session_stats['created']}, 总关闭:{_session_stats['closed']}, 错误:{_session_stats['errors']}") - if _session_stats["active"] > 100: - logger.warning(f"🚨 活跃会话数过多: {_session_stats['active']},可能存在连接泄漏!") + # 使用优化后的会话监控阈值 + if _session_stats["active"] > settings.database_session_leak_threshold: + logger.error(f"🚨 严重告警:活跃会话数 {_session_stats['active']} 超过泄漏阈值 {settings.database_session_leak_threshold}!") + elif _session_stats["active"] > settings.database_session_max_active: + logger.warning(f"⚠️ 警告:活跃会话数 {_session_stats['active']} 超过警告阈值 {settings.database_session_max_active},可能存在连接泄漏!") elif _session_stats["active"] < 0: logger.error(f"🚨 活跃会话数异常: {_session_stats['active']},统计可能不准确!") @@ -324,4 +386,154 @@ async def close_db(): logger.info("所有数据库连接已关闭") except Exception as e: logger.error(f"关闭数据库连接失败: {str(e)}", exc_info=True) - raise \ No newline at end of file + raise + +async def get_database_stats(): + """获取数据库连接和会话统计信息 + + Returns: + dict: 包含数据库统计信息的字典 + """ + from app.config import settings + + stats = { + "session_stats": { + "created": _session_stats["created"], + "closed": _session_stats["closed"], + "active": _session_stats["active"], + "errors": _session_stats["errors"], + "generator_exits": _session_stats["generator_exits"], + "last_check": _session_stats["last_check"], + }, + "engine_cache": { + "total_engines": len(_engine_cache), + "engine_keys": list(_engine_cache.keys()), + }, + "config": { + "database_type": "PostgreSQL" if "postgresql" in settings.database_url else "SQLite", + "pool_size": settings.database_pool_size, + "max_overflow": settings.database_max_overflow, + "total_connections": settings.database_pool_size + settings.database_max_overflow, + "pool_timeout": settings.database_pool_timeout, + "session_max_active_threshold": settings.database_session_max_active, + "session_leak_threshold": settings.database_session_leak_threshold, + }, + "health": { + "status": "healthy", + "warnings": [], + "errors": [], + } + } + + # 健康检查 + if _session_stats["active"] > settings.database_session_leak_threshold: + stats["health"]["status"] = "critical" + stats["health"]["errors"].append( + f"活跃会话数 {_session_stats['active']} 超过泄漏阈值 {settings.database_session_leak_threshold}" + ) + elif _session_stats["active"] > settings.database_session_max_active: + stats["health"]["status"] = "warning" + stats["health"]["warnings"].append( + f"活跃会话数 {_session_stats['active']} 超过警告阈值 {settings.database_session_max_active}" + ) + + if _session_stats["active"] < 0: + stats["health"]["status"] = "error" + stats["health"]["errors"].append(f"活跃会话数异常: {_session_stats['active']}") + + error_rate = (_session_stats["errors"] / max(_session_stats["created"], 1)) * 100 + if error_rate > 5: + stats["health"]["status"] = "warning" + stats["health"]["warnings"].append(f"会话错误率过高: {error_rate:.2f}%") + + stats["health"]["error_rate"] = f"{error_rate:.2f}%" + + return stats + + +async def check_database_health(user_id: str = None) -> dict: + """检查数据库连接健康状态 + + Args: + user_id: 可选的用户ID,如果提供则检查特定用户的数据库 + + Returns: + dict: 健康检查结果 + """ + result = { + "healthy": True, + "checks": {}, + "timestamp": datetime.now().isoformat() + } + + try: + # 检查引擎是否存在 + if user_id: + engine = await get_engine(user_id) + cache_key = user_id + else: + if "postgresql" in settings.database_url: + cache_key = "shared_postgres" + if cache_key not in _engine_cache: + result["checks"]["engine"] = {"status": "not_initialized", "healthy": True} + return result + engine = _engine_cache[cache_key] + else: + result["checks"]["engine"] = {"status": "skipped", "message": "需要提供user_id检查SQLite"} + return result + + # 测试数据库连接 + AsyncSessionLocal = async_sessionmaker( + engine, + class_=AsyncSession, + expire_on_commit=False + ) + + async with AsyncSessionLocal() as session: + # 执行简单查询测试连接 + await session.execute(text("SELECT 1")) + result["checks"]["connection"] = {"status": "ok", "healthy": True} + + # 检查连接池状态(仅PostgreSQL) + if hasattr(engine.pool, 'size'): + pool_status = { + "size": engine.pool.size(), + "checked_in": engine.pool.checkedin(), + "checked_out": engine.pool.checkedout(), + "overflow": engine.pool.overflow(), + "healthy": True + } + + # 连接池健康检查 + if engine.pool.overflow() >= settings.database_max_overflow: + pool_status["healthy"] = False + pool_status["warning"] = "连接池溢出已满" + result["healthy"] = False + + result["checks"]["pool"] = pool_status + + except Exception as e: + result["healthy"] = False + result["checks"]["error"] = { + "status": "error", + "message": str(e), + "healthy": False + } + logger.error(f"数据库健康检查失败: {str(e)}", exc_info=True) + + return result + + +async def reset_session_stats(): + """重置会话统计信息(用于测试或维护)""" + global _session_stats + _session_stats = { + "created": 0, + "closed": 0, + "active": 0, + "errors": 0, + "generator_exits": 0, + "last_check": datetime.now().isoformat() + } + logger.info("✅ 会话统计信息已重置") + return _session_stats \ No newline at end of file diff --git a/backend/app/main.py b/backend/app/main.py index 8679f31..589b685 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -28,7 +28,6 @@ logger = get_logger(__name__) async def lifespan(app: FastAPI): """应用生命周期管理""" logger.info("应用启动,等待用户登录...") - logger.info("💡 MCP插件采用延迟加载策略,将在用户首次使用时自动加载") yield diff --git a/backend/app/mcp/http_client.py b/backend/app/mcp/http_client.py index 604f338..f8d75de 100644 --- a/backend/app/mcp/http_client.py +++ b/backend/app/mcp/http_client.py @@ -267,13 +267,26 @@ class HTTPMCPClient: start_time = time.time() try: - # 尝试连接并列举工具 + # 尝试连接并列举工具(直接调用SDK,避免重复日志) await self._ensure_connected() - tools = await self.list_tools() + + result = await self._session.list_tools() + + # 转换为字典格式 + tools = [] + for tool in result.tools: + tool_dict = { + "name": tool.name, + "description": tool.description or "", + "inputSchema": tool.inputSchema + } + tools.append(tool_dict) end_time = time.time() response_time = round((end_time - start_time) * 1000, 2) + logger.info(f"✅ 连接测试成功,获取到 {len(tools)} 个工具") + return { "success": True, "message": "连接测试成功", diff --git a/backend/app/models/character.py b/backend/app/models/character.py index 2363dca..d6c0115 100644 --- a/backend/app/models/character.py +++ b/backend/app/models/character.py @@ -14,8 +14,8 @@ class Character(Base): # 基本信息 name = Column(String(100), nullable=False, comment="角色/组织名称") - age = Column(String(20), comment="年龄") - gender = Column(String(20), comment="性别") + age = Column(String(50), comment="年龄") + gender = Column(String(50), comment="性别") is_organization = Column(Boolean, default=False, comment="是否为组织") # 角色类型:protagonist(主角)/supporting(配角)/antagonist(反派) diff --git a/backend/app/models/memory.py b/backend/app/models/memory.py index 881ad4b..66670e8 100644 --- a/backend/app/models/memory.py +++ b/backend/app/models/memory.py @@ -9,7 +9,7 @@ class StoryMemory(Base): """故事记忆表 - 存储结构化的故事片段和元数据""" __tablename__ = "story_memories" - id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + id = Column(String(100), primary_key=True, default=lambda: str(uuid.uuid4())) project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False, index=True) chapter_id = Column(String(36), ForeignKey("chapters.id", ondelete="CASCADE"), nullable=True, index=True) @@ -45,7 +45,7 @@ class StoryMemory(Base): # 伏笔相关字段 is_foreshadow = Column(Integer, default=0, comment="伏笔状态: 0=普通记忆, 1=已埋下伏笔, 2=伏笔已回收") - foreshadow_resolved_at = Column(String(36), ForeignKey("chapters.id", ondelete="SET NULL"), comment="伏笔回收的章节ID") + foreshadow_resolved_at = Column(String(100), ForeignKey("chapters.id", ondelete="SET NULL"), comment="伏笔回收的章节ID") foreshadow_strength = Column(Float, comment="伏笔强度 0.0-1.0") # 向量数据库关联 diff --git a/backend/app/models/project.py b/backend/app/models/project.py index 1913067..d69d57c 100644 --- a/backend/app/models/project.py +++ b/backend/app/models/project.py @@ -10,6 +10,7 @@ class Project(Base): __tablename__ = "projects" id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + user_id = Column(String(36), nullable=False, index=True, comment="用户ID") title = Column(String(200), nullable=False, comment="项目标题") description = Column(Text, comment="项目简介") theme = Column(Text, comment="主题") diff --git a/backend/app/models/relationship.py b/backend/app/models/relationship.py index 70879dc..8b73c0f 100644 --- a/backend/app/models/relationship.py +++ b/backend/app/models/relationship.py @@ -75,7 +75,7 @@ class Organization(Base): # 组织特色 motto = Column(String(200), comment="宗旨/口号") - color = Column(String(20), comment="代表颜色") + color = Column(String(100), comment="代表颜色") created_at = Column(DateTime, server_default=func.now(), comment="创建时间") updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间") diff --git a/backend/app/services/ai_service.py b/backend/app/services/ai_service.py index 4bfeb97..3d8e2c0 100644 --- a/backend/app/services/ai_service.py +++ b/backend/app/services/ai_service.py @@ -82,7 +82,9 @@ class AIService: self.openai_http_client = None self.openai_api_key = None self.openai_base_url = None - logger.warning("OpenAI API key未配置") + # 只有当用户明确选择OpenAI作为提供商时才警告 + if self.api_provider == "openai": + logger.warning("⚠️ OpenAI API key未配置,但被设置为当前AI提供商") # 初始化Anthropic客户端 anthropic_key = api_key if api_provider == "anthropic" else app_settings.anthropic_api_key @@ -118,7 +120,9 @@ class AIService: self.anthropic_client = None else: self.anthropic_client = None - logger.warning("Anthropic API key未配置") + # 只有当用户明确选择Anthropic作为提供商时才警告 + if self.api_provider == "anthropic": + logger.warning("⚠️ Anthropic API key未配置,但被设置为当前AI提供商") async def generate_text( self, diff --git a/backend/app/services/memory_service.py b/backend/app/services/memory_service.py index 081e095..a467385 100644 --- a/backend/app/services/memory_service.py +++ b/backend/app/services/memory_service.py @@ -87,7 +87,7 @@ class MemoryService: 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2', cache_folder=model_cache_dir, device='cpu', # 明确指定使用CPU - trust_remote_code=False # 安全起见 + trust_remote_code=False, # 安全起见 ) logger.info("✅ Embedding模型加载成功 (paraphrase-multilingual-MiniLM-L12-v2)") except Exception as e: diff --git a/backend/embedding/models--sentence-transformers--paraphrase-multilingual-MiniLM-L12-v2/.no_exist/86741b4e3f5cb7765a600d3a3d55a0f6a6cb443d/adapter_config.json b/backend/embedding/models--sentence-transformers--paraphrase-multilingual-MiniLM-L12-v2/.no_exist/86741b4e3f5cb7765a600d3a3d55a0f6a6cb443d/adapter_config.json new file mode 100644 index 0000000..e69de29 diff --git a/backend/requirements.txt b/backend/requirements.txt index 138c591..49b1aee 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -5,7 +5,9 @@ python-multipart==0.0.20 # 数据库 sqlalchemy==2.0.25 -aiosqlite==0.19.0 +aiosqlite==0.19.0 # SQLite支持(保留用于开发环境) +asyncpg==0.29.0 # PostgreSQL异步驱动(生产环境) +psycopg2-binary==2.9.9 # PostgreSQL同步驱动(备用) # 数据验证 pydantic==2.12.4 @@ -29,8 +31,8 @@ numpy==1.26.4 chromadb==1.3.2 -# Transformers(锁定兼容版本) -transformers==4.35.2 +# Transformers(更新到最新稳定版本以修复 FutureWarning) +transformers==4.57.1 -# Sentence Transformers(基于PyTorch的文本embedding库) -sentence-transformers==2.3.1 +# Sentence Transformers(更新到最新稳定版本以修复 FutureWarning) +sentence-transformers==5.1.2 diff --git a/backend/scripts/init_postgres.sql b/backend/scripts/init_postgres.sql new file mode 100644 index 0000000..c8fe829 --- /dev/null +++ b/backend/scripts/init_postgres.sql @@ -0,0 +1,30 @@ +-- PostgreSQL 初始化脚本 +-- 此脚本会在PostgreSQL容器首次启动时自动执行 + +-- 确保使用UTF8编码 +SET client_encoding = 'UTF8'; + +-- 创建必要的扩展 +CREATE EXTENSION IF NOT EXISTS "uuid-ossp"; +CREATE EXTENSION IF NOT EXISTS "pg_trgm"; + +-- 设置时区 +SET timezone = 'Asia/Shanghai'; + +-- 优化配置(这些设置会在容器启动后生效) +-- 注意:部分配置已在docker-compose.yml的command中设置 + +-- 创建索引优化查询性能(表会由SQLAlchemy自动创建) +-- 这里只是预留空间,实际索引会在应用启动时创建 + +-- 输出初始化信息 +DO $$ +BEGIN + RAISE NOTICE '=================================================='; + RAISE NOTICE 'MuMuAINovel PostgreSQL 数据库初始化完成'; + RAISE NOTICE '数据库名称: mumuai_novel'; + RAISE NOTICE '字符编码: UTF8'; + RAISE NOTICE '时区设置: Asia/Shanghai'; + RAISE NOTICE '扩展已安装: uuid-ossp, pg_trgm'; + RAISE NOTICE '=================================================='; +END $$; \ No newline at end of file diff --git a/backend/scripts/migrate_sqlite_to_postgres.py b/backend/scripts/migrate_sqlite_to_postgres.py new file mode 100644 index 0000000..454bce8 --- /dev/null +++ b/backend/scripts/migrate_sqlite_to_postgres.py @@ -0,0 +1,816 @@ +#!/usr/bin/env python3 +""" +SQLite to PostgreSQL 数据迁移脚本 + +使用方法: + python backend/scripts/migrate_sqlite_to_postgres.py + +前置条件: + 1. PostgreSQL数据库已创建 + 2. .env文件中DATABASE_URL已配置为PostgreSQL + 3. SQLite数据文件存在于 backend/data/ 目录 +""" +import asyncio +import sys +from pathlib import Path +from typing import List, Dict, Any +import logging +from datetime import datetime + +# 添加项目根目录到Python路径 +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from sqlalchemy import create_engine, text, select +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker +from app.database import Base +from app.models import ( + Project, Outline, Character, Chapter, GenerationHistory, + Settings, WritingStyle, ProjectDefaultStyle, + RelationshipType, CharacterRelationship, Organization, OrganizationMember, + StoryMemory, PlotAnalysis, AnalysisTask, BatchGenerationTask, + MCPPlugin +) +from app.config import settings + +# 创建日志目录 +log_dir = Path(__file__).parent.parent / "logs" +log_dir.mkdir(exist_ok=True) + +# 生成日志文件名(带时间戳) +log_filename = log_dir / f"migration_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log" + +# 设置日志 - 同时输出到控制台和文件 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler(), # 控制台输出 + logging.FileHandler(log_filename, encoding='utf-8') # 文件输出 + ] +) +logger = logging.getLogger(__name__) +logger.info(f"📝 日志文件: {log_filename}") + + +class SQLiteToPostgresMigrator: + """SQLite到PostgreSQL的数据迁移器""" + + def __init__(self, sqlite_dir: Path, target_user_id: str): + """ + 初始化迁移器 + + Args: + sqlite_dir: SQLite数据库文件目录 + target_user_id: 目标用户ID(迁移后的数据归属) + """ + self.sqlite_dir = sqlite_dir + self.target_user_id = target_user_id + self.sqlite_files = list(sqlite_dir.glob("ai_story_user_*.db")) + + # PostgreSQL连接 + if "postgresql" not in settings.database_url: + raise ValueError("DATABASE_URL必须配置为PostgreSQL") + + self.pg_engine = create_async_engine( + settings.database_url, + echo=False, + pool_pre_ping=True + ) + + self.pg_session_maker = async_sessionmaker( + self.pg_engine, + class_=AsyncSession, + expire_on_commit=False + ) + + async def migrate_all(self): + """迁移所有SQLite数据库""" + if not self.sqlite_files: + logger.warning(f"未找到SQLite数据库文件: {self.sqlite_dir}") + return + + logger.info(f"找到 {len(self.sqlite_files)} 个SQLite数据库文件") + + # 创建PostgreSQL表结构 + await self._create_tables() + + # 初始化关系类型数据 + await self._init_relationship_types() + + # 逐个迁移 + for sqlite_file in self.sqlite_files: + await self._migrate_single_db(sqlite_file) + + logger.info("✅ 所有数据迁移完成") + + async def _create_tables(self): + """创建PostgreSQL表结构""" + logger.info("创建PostgreSQL表结构...") + async with self.pg_engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + logger.info("✅ 表结构创建完成") + + async def _init_relationship_types(self): + """初始化关系类型数据""" + logger.info("初始化关系类型数据...") + + # 预置关系类型数据 + relationship_types = [ + # 家族关系 + {"name": "父亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👨"}, + {"name": "母亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👩"}, + {"name": "兄弟", "category": "family", "reverse_name": "兄弟", "intimacy_range": "high", "icon": "👬"}, + {"name": "姐妹", "category": "family", "reverse_name": "姐妹", "intimacy_range": "high", "icon": "👭"}, + {"name": "子女", "category": "family", "reverse_name": "父母", "intimacy_range": "high", "icon": "👶"}, + {"name": "配偶", "category": "family", "reverse_name": "配偶", "intimacy_range": "high", "icon": "💑"}, + {"name": "恋人", "category": "family", "reverse_name": "恋人", "intimacy_range": "high", "icon": "💕"}, + + # 社交关系 + {"name": "师父", "category": "social", "reverse_name": "徒弟", "intimacy_range": "high", "icon": "🎓"}, + {"name": "徒弟", "category": "social", "reverse_name": "师父", "intimacy_range": "high", "icon": "📚"}, + {"name": "朋友", "category": "social", "reverse_name": "朋友", "intimacy_range": "medium", "icon": "🤝"}, + {"name": "同学", "category": "social", "reverse_name": "同学", "intimacy_range": "medium", "icon": "🎒"}, + {"name": "邻居", "category": "social", "reverse_name": "邻居", "intimacy_range": "low", "icon": "🏘️"}, + {"name": "知己", "category": "social", "reverse_name": "知己", "intimacy_range": "high", "icon": "💙"}, + + # 职业关系 + {"name": "上司", "category": "professional", "reverse_name": "下属", "intimacy_range": "low", "icon": "👔"}, + {"name": "下属", "category": "professional", "reverse_name": "上司", "intimacy_range": "low", "icon": "💼"}, + {"name": "同事", "category": "professional", "reverse_name": "同事", "intimacy_range": "medium", "icon": "🤵"}, + {"name": "合作伙伴", "category": "professional", "reverse_name": "合作伙伴", "intimacy_range": "medium", "icon": "🤜🤛"}, + + # 敌对关系 + {"name": "敌人", "category": "hostile", "reverse_name": "敌人", "intimacy_range": "low", "icon": "⚔️"}, + {"name": "仇人", "category": "hostile", "reverse_name": "仇人", "intimacy_range": "low", "icon": "💢"}, + {"name": "竞争对手", "category": "hostile", "reverse_name": "竞争对手", "intimacy_range": "low", "icon": "🎯"}, + {"name": "宿敌", "category": "hostile", "reverse_name": "宿敌", "intimacy_range": "low", "icon": "⚡"}, + ] + + try: + async with self.pg_session_maker() as session: + # 检查是否已经有数据 + result = await session.execute(select(RelationshipType)) + existing = result.scalars().first() + + if existing: + logger.info("关系类型数据已存在,跳过初始化") + return + + # 插入预置数据 + logger.info("开始插入关系类型数据...") + for rt_data in relationship_types: + relationship_type = RelationshipType(**rt_data) + session.add(relationship_type) + + await session.commit() + logger.info(f"✅ 成功插入 {len(relationship_types)} 条关系类型数据") + + except Exception as e: + logger.error(f"初始化关系类型数据失败: {str(e)}", exc_info=True) + # 不抛出异常,继续迁移流程 + logger.warning("关系类型初始化失败,将跳过有外键依赖的记录") + + async def _migrate_single_db(self, sqlite_file: Path): + """迁移单个SQLite数据库""" + # 从文件名提取user_id + filename = sqlite_file.stem # ai_story_user_xxx + if filename.startswith("ai_story_user_"): + user_id = filename.replace("ai_story_user_", "") + else: + user_id = self.target_user_id + + logger.info(f"\n{'='*60}") + logger.info(f"开始迁移: {sqlite_file.name} -> user_id: {user_id}") + logger.info(f"{'='*60}") + + # 创建SQLite连接 + sqlite_url = f"sqlite+aiosqlite:///{sqlite_file.absolute()}" + sqlite_engine = create_async_engine(sqlite_url, echo=False) + sqlite_session_maker = async_sessionmaker( + sqlite_engine, + class_=AsyncSession, + expire_on_commit=False + ) + + try: + # 迁移各个表 + async with sqlite_session_maker() as sqlite_session: + async with self.pg_session_maker() as pg_session: + # 按照依赖顺序迁移 + await self._migrate_table( + sqlite_session, pg_session, user_id, Settings, "设置" + ) + await self._migrate_table( + sqlite_session, pg_session, user_id, Project, "项目" + ) + await self._migrate_table( + sqlite_session, pg_session, user_id, Character, "角色" + ) + await self._migrate_table( + sqlite_session, pg_session, user_id, Outline, "大纲" + ) + await self._migrate_table( + sqlite_session, pg_session, user_id, Chapter, "章节" + ) + await self._migrate_table( + sqlite_session, pg_session, user_id, CharacterRelationship, "角色关系" + ) + await self._migrate_table( + sqlite_session, pg_session, user_id, Organization, "组织" + ) + await self._migrate_table( + sqlite_session, pg_session, user_id, OrganizationMember, "组织成员" + ) + await self._migrate_table( + sqlite_session, pg_session, user_id, GenerationHistory, "生成历史" + ) + await self._migrate_table( + sqlite_session, pg_session, user_id, WritingStyle, "写作风格" + ) + await self._migrate_table( + sqlite_session, pg_session, user_id, ProjectDefaultStyle, "项目默认风格" + ) + await self._migrate_table( + sqlite_session, pg_session, user_id, StoryMemory, "记忆" + ) + await self._migrate_table( + sqlite_session, pg_session, user_id, PlotAnalysis, "剧情分析" + ) + await self._migrate_table( + sqlite_session, pg_session, user_id, AnalysisTask, "分析任务" + ) + await self._migrate_table( + sqlite_session, pg_session, user_id, BatchGenerationTask, "批量生成任务" + ) + await self._migrate_table( + sqlite_session, pg_session, user_id, MCPPlugin, "MCP插件" + ) + + await pg_session.commit() + + logger.info(f"✅ {sqlite_file.name} 迁移完成") + + except Exception as e: + logger.error(f"❌ 迁移失败: {e}", exc_info=True) + finally: + await sqlite_engine.dispose() + + async def _migrate_table( + self, + sqlite_session: AsyncSession, + pg_session: AsyncSession, + user_id: str, + model_class, + table_name: str + ): + """迁移单个表的数据""" + try: + # 获取SQLite表中实际存在的列 + sqlite_table = model_class.__table__ + sqlite_conn = await sqlite_session.connection() + + # 查询SQLite表结构 + inspect_result = await sqlite_conn.execute( + text(f"PRAGMA table_info({sqlite_table.name})") + ) + sqlite_columns = {row[1] for row in inspect_result.fetchall()} # row[1]是列名 + + # 构建只包含SQLite中存在的列的查询 + available_columns = [ + c for c in model_class.__table__.columns + if c.name in sqlite_columns + ] + + if not available_columns: + logger.warning(f" ⚠️ {table_name}: 表结构不匹配,跳过") + return + + # 从SQLite读取数据(只查询存在的列) + result = await sqlite_session.execute( + select(*available_columns) + ) + records = result.all() + + if not records: + logger.info(f" - {table_name}: 无数据") + return + + # 为每条记录创建字典并添加user_id + migrated_count = 0 + skipped_count = 0 + + for record in records: + # 从查询结果构建字典 + record_dict = {} + for i, col in enumerate(available_columns): + record_dict[col.name] = record[i] + + # 添加user_id(如果PostgreSQL模型有这个字段但SQLite没有) + if hasattr(model_class, 'user_id') and 'user_id' not in record_dict: + record_dict['user_id'] = user_id + + # 验证字段长度(防止超长字段导致插入失败) + if not self._validate_field_lengths(model_class, record_dict, table_name): + skipped_count += 1 + record_id = record_dict.get('id', 'unknown') + logger.warning(f" ⚠️ [{table_name}] 跳过超长字段记录 ID={record_id}") + continue + + # 验证外键引用(针对有外键的表) + validation_result = await self._validate_foreign_keys(pg_session, model_class, record_dict) + if not validation_result: + skipped_count += 1 + record_id = record_dict.get('id', 'unknown') + logger.warning(f" ⚠️ [{table_name}] 跳过无效外键记录 ID={record_id}") + # 输出记录详情以便调试 + if model_class.__tablename__ == 'story_memories': + logger.warning(f" 记忆详情: project_id={record_dict.get('project_id')}, " + f"chapter_id={record_dict.get('chapter_id')}, " + f"type={record_dict.get('memory_type')}") + elif model_class.__tablename__ == 'character_relationships': + logger.warning(f" 关系详情: project_id={record_dict.get('project_id')}, " + f"from={record_dict.get('character_from_id')}, " + f"to={record_dict.get('character_to_id')}, " + f"type_id={record_dict.get('relationship_type_id')}") + elif model_class.__tablename__ == 'organizations': + logger.warning(f" 组织详情: project_id={record_dict.get('project_id')}, " + f"character_id={record_dict.get('character_id')}") + elif model_class.__tablename__ == 'organization_members': + logger.warning(f" 成员详情: org_id={record_dict.get('organization_id')}, " + f"character_id={record_dict.get('character_id')}") + elif model_class.__tablename__ == 'writing_styles': + logger.warning(f" 写作风格详情: project_id={record_dict.get('project_id')}, " + f"name={record_dict.get('name')}, " + f"style_type={record_dict.get('style_type')}") + elif model_class.__tablename__ == 'characters': + logger.warning(f" 角色详情: project_id={record_dict.get('project_id')}, " + f"name={record_dict.get('name')}, " + f"is_organization={record_dict.get('is_organization')}") + elif model_class.__tablename__ == 'outlines': + logger.warning(f" 大纲详情: project_id={record_dict.get('project_id')}, " + f"title={record_dict.get('title')}") + elif model_class.__tablename__ == 'chapters': + logger.warning(f" 章节详情: project_id={record_dict.get('project_id')}, " + f"title={record_dict.get('title')}, " + f"chapter_number={record_dict.get('chapter_number')}") + elif model_class.__tablename__ == 'generation_history': + logger.warning(f" 生成历史详情: project_id={record_dict.get('project_id')}, " + f"chapter_id={record_dict.get('chapter_id')}, " + f"model={record_dict.get('model')}") + elif model_class.__tablename__ == 'plot_analysis': + logger.warning(f" 剧情分析详情: project_id={record_dict.get('project_id')}, " + f"chapter_id={record_dict.get('chapter_id')}, " + f"plot_stage={record_dict.get('plot_stage')}") + elif model_class.__tablename__ == 'analysis_tasks': + logger.warning(f" 分析任务详情: chapter_id={record_dict.get('chapter_id')}, " + f"project_id={record_dict.get('project_id')}, " + f"status={record_dict.get('status')}") + elif model_class.__tablename__ == 'batch_generation_tasks': + logger.warning(f" 批量生成任务详情: project_id={record_dict.get('project_id')}, " + f"status={record_dict.get('status')}, " + f"completed={record_dict.get('completed_chapters')}/{record_dict.get('total_chapters')}") + elif model_class.__tablename__ == 'project_default_styles': + logger.warning(f" 项目默认风格详情: project_id={record_dict.get('project_id')}, " + f"style_id={record_dict.get('style_id')}") + continue + + # 检查记录是否已存在(避免主键冲突) + record_id = record_dict.get('id') + if record_id and await self._record_exists(pg_session, model_class, record_id): + skipped_count += 1 + logger.debug(f" 跳过已存在的记录: {record_id}") + continue + + # 创建新记录 + try: + new_record = model_class(**record_dict) + pg_session.add(new_record) + migrated_count += 1 + except Exception as e: + logger.warning(f" ⚠️ 跳过无效记录: {str(e)[:100]}") + skipped_count += 1 + continue + + await pg_session.flush() + + if skipped_count > 0: + logger.info(f" ✅ {table_name}: {migrated_count} 条记录(跳过 {skipped_count} 条无效记录)") + else: + logger.info(f" ✅ {table_name}: {migrated_count} 条记录") + + except Exception as e: + logger.error(f" ❌ {table_name} 迁移失败: {e}") + raise + + async def _record_exists( + self, + pg_session: AsyncSession, + model_class, + record_id: Any + ) -> bool: + """ + 检查记录是否已存在 + + Args: + pg_session: PostgreSQL会话 + model_class: 模型类 + record_id: 记录ID + + Returns: + bool: 记录是否存在 + """ + try: + # 获取主键列 + pk_column = list(model_class.__table__.primary_key.columns)[0] + result = await pg_session.execute( + select(pk_column).where(pk_column == record_id) + ) + return result.scalar_one_or_none() is not None + except Exception: + return False + + async def _validate_foreign_keys( + self, + pg_session: AsyncSession, + model_class, + record_dict: Dict[str, Any] + ) -> bool: + """ + 验证记录的外键是否有效 + + Args: + pg_session: PostgreSQL会话 + model_class: 模型类 + record_dict: 记录字典 + + Returns: + bool: 外键是否全部有效 + """ + from app.models import Character, Project, Chapter + + # 使用no_autoflush防止过早flush + with pg_session.no_autoflush: + # 针对StoryMemory表验证外键 + if model_class.__tablename__ == 'story_memories': + # 验证project_id + project_id = record_dict.get('project_id') + if project_id: + result = await pg_session.execute( + select(Project.id).where(Project.id == project_id) + ) + if not result.scalar_one_or_none(): + logger.warning(f" ❌ [记忆] 无效的project_id: {project_id}") + return False + + # 验证chapter_id(可选) + chapter_id = record_dict.get('chapter_id') + if chapter_id: + result = await pg_session.execute( + select(Chapter.id).where(Chapter.id == chapter_id) + ) + if not result.scalar_one_or_none(): + logger.warning(f" ❌ [记忆] 无效的chapter_id: {chapter_id}") + return False + + # 针对CharacterRelationship表验证外键 + elif model_class.__tablename__ == 'character_relationships': + # 验证project_id + project_id = record_dict.get('project_id') + if project_id: + result = await pg_session.execute( + select(Project.id).where(Project.id == project_id) + ) + if not result.scalar_one_or_none(): + logger.warning(f" ❌ 无效的project_id: {project_id}") + return False + + # 验证character_from_id + char_from_id = record_dict.get('character_from_id') + if char_from_id: + result = await pg_session.execute( + select(Character.id).where(Character.id == char_from_id) + ) + if not result.scalar_one_or_none(): + logger.warning(f" ❌ 无效的character_from_id: {char_from_id}") + return False + + # 验证character_to_id + char_to_id = record_dict.get('character_to_id') + if char_to_id: + result = await pg_session.execute( + select(Character.id).where(Character.id == char_to_id) + ) + if not result.scalar_one_or_none(): + logger.warning(f" ❌ 无效的character_to_id: {char_to_id}") + return False + + # 验证relationship_type_id + rel_type_id = record_dict.get('relationship_type_id') + if rel_type_id: + result = await pg_session.execute( + select(RelationshipType.id).where(RelationshipType.id == rel_type_id) + ) + if not result.scalar_one_or_none(): + logger.warning(f" ❌ 无效的relationship_type_id: {rel_type_id}") + return False + + # 针对Organization表验证外键 + elif model_class.__tablename__ == 'organizations': + # 验证character_id + char_id = record_dict.get('character_id') + if char_id: + result = await pg_session.execute( + select(Character.id).where(Character.id == char_id) + ) + if not result.scalar_one_or_none(): + logger.warning(f" ❌ [组织] 无效的character_id: {char_id}") + return False + + # 针对OrganizationMember表验证外键 + elif model_class.__tablename__ == 'organization_members': + from app.models import Organization + + # 验证organization_id + org_id = record_dict.get('organization_id') + if org_id: + result = await pg_session.execute( + select(Organization.id).where(Organization.id == org_id) + ) + if not result.scalar_one_or_none(): + logger.warning(f" ❌ 无效的organization_id: {org_id}") + return False + + # 验证character_id + char_id = record_dict.get('character_id') + if char_id: + result = await pg_session.execute( + select(Character.id).where(Character.id == char_id) + ) + if not result.scalar_one_or_none(): + logger.warning(f" ❌ [组织成员] 无效的character_id: {char_id}") + return False + + # 针对Character表验证外键 + elif model_class.__tablename__ == 'characters': + # 验证project_id + project_id = record_dict.get('project_id') + if project_id: + result = await pg_session.execute( + select(Project.id).where(Project.id == project_id) + ) + if not result.scalar_one_or_none(): + logger.warning(f" ❌ [角色] 无效的project_id: {project_id}") + return False + + # 针对Outline表验证外键 + elif model_class.__tablename__ == 'outlines': + # 验证project_id + project_id = record_dict.get('project_id') + if project_id: + result = await pg_session.execute( + select(Project.id).where(Project.id == project_id) + ) + if not result.scalar_one_or_none(): + logger.warning(f" ❌ [大纲] 无效的project_id: {project_id}") + return False + + # 针对Chapter表验证外键 + elif model_class.__tablename__ == 'chapters': + # 验证project_id + project_id = record_dict.get('project_id') + if project_id: + result = await pg_session.execute( + select(Project.id).where(Project.id == project_id) + ) + if not result.scalar_one_or_none(): + logger.warning(f" ❌ [章节] 无效的project_id: {project_id}") + return False + + # 针对WritingStyle表验证外键 + elif model_class.__tablename__ == 'writing_styles': + # 验证project_id(可选) + project_id = record_dict.get('project_id') + if project_id: + result = await pg_session.execute( + select(Project.id).where(Project.id == project_id) + ) + if not result.scalar_one_or_none(): + logger.warning(f" ❌ [写作风格] 无效的project_id: {project_id}") + return False + + # 针对GenerationHistory表验证外键 + elif model_class.__tablename__ == 'generation_history': + # 验证project_id + project_id = record_dict.get('project_id') + if project_id: + result = await pg_session.execute( + select(Project.id).where(Project.id == project_id) + ) + if not result.scalar_one_or_none(): + logger.warning(f" ❌ [生成历史] 无效的project_id: {project_id}") + return False + + # 验证chapter_id(可选) + chapter_id = record_dict.get('chapter_id') + if chapter_id: + result = await pg_session.execute( + select(Chapter.id).where(Chapter.id == chapter_id) + ) + if not result.scalar_one_or_none(): + logger.warning(f" ❌ [生成历史] 无效的chapter_id: {chapter_id}") + return False + + # 针对PlotAnalysis表验证外键 + elif model_class.__tablename__ == 'plot_analysis': + # 验证project_id(必需) + project_id = record_dict.get('project_id') + if project_id: + result = await pg_session.execute( + select(Project.id).where(Project.id == project_id) + ) + if not result.scalar_one_or_none(): + logger.warning(f" ❌ [剧情分析] 无效的project_id: {project_id}") + return False + + # 验证chapter_id(必需) + chapter_id = record_dict.get('chapter_id') + if chapter_id: + result = await pg_session.execute( + select(Chapter.id).where(Chapter.id == chapter_id) + ) + if not result.scalar_one_or_none(): + logger.warning(f" ❌ [剧情分析] 无效的chapter_id: {chapter_id}") + return False + + # 针对AnalysisTask表验证外键 + elif model_class.__tablename__ == 'analysis_tasks': + # 验证chapter_id(必需) + chapter_id = record_dict.get('chapter_id') + if chapter_id: + result = await pg_session.execute( + select(Chapter.id).where(Chapter.id == chapter_id) + ) + if not result.scalar_one_or_none(): + logger.warning(f" ❌ [分析任务] 无效的chapter_id: {chapter_id}") + return False + + # 验证project_id + project_id = record_dict.get('project_id') + if project_id: + result = await pg_session.execute( + select(Project.id).where(Project.id == project_id) + ) + if not result.scalar_one_or_none(): + logger.warning(f" ❌ [分析任务] 无效的project_id: {project_id}") + return False + + # 针对BatchGenerationTask表验证外键 + elif model_class.__tablename__ == 'batch_generation_tasks': + # 验证project_id(必需) + project_id = record_dict.get('project_id') + if project_id: + result = await pg_session.execute( + select(Project.id).where(Project.id == project_id) + ) + if not result.scalar_one_or_none(): + logger.warning(f" ❌ [批量生成任务] 无效的project_id: {project_id}") + return False + + # 针对ProjectDefaultStyle表验证外键 + elif model_class.__tablename__ == 'project_default_styles': + from app.models import WritingStyle + + # 验证project_id(必需) + project_id = record_dict.get('project_id') + if project_id: + result = await pg_session.execute( + select(Project.id).where(Project.id == project_id) + ) + if not result.scalar_one_or_none(): + logger.warning(f" ❌ [项目默认风格] 无效的project_id: {project_id}") + return False + + # 验证style_id(必需) + style_id = record_dict.get('style_id') + if style_id: + result = await pg_session.execute( + select(WritingStyle.id).where(WritingStyle.id == style_id) + ) + if not result.scalar_one_or_none(): + logger.warning(f" ❌ [项目默认风格] 无效的style_id: {style_id}") + return False + + return True + + def _validate_field_lengths( + self, + model_class, + record_dict: Dict[str, Any], + table_name: str + ) -> bool: + """ + 验证记录的字段长度是否符合模型定义 + + Args: + model_class: 模型类 + record_dict: 记录字典 + table_name: 表名(用于日志) + + Returns: + bool: 字段长度是否全部有效 + """ + from sqlalchemy import String + + # 检查所有字符串类型字段 + for column in model_class.__table__.columns: + # 只检查有长度限制的String类型字段 + if isinstance(column.type, String) and column.type.length: + field_name = column.name + field_value = record_dict.get(field_name) + max_length = column.type.length + + # 如果字段有值且超过最大长度 + if field_value and isinstance(field_value, str) and len(field_value) > max_length: + logger.warning( + f" ❌ [{table_name}] 字段 '{field_name}' 超长: " + f"{len(field_value)} > {max_length} (截断了 {len(field_value) - max_length} 字符)" + ) + # 对于敏感字段如API密钥,记录部分内容 + if field_name in ['api_key', 'api_base_url']: + preview = field_value[:50] + "..." + field_value[-20:] if len(field_value) > 70 else field_value + logger.warning(f" 值预览: {preview}") + return False + + return True + + async def cleanup(self): + """清理资源""" + await self.pg_engine.dispose() + + +async def main(): + """主函数""" + banner = """ +╔══════════════════════════════════════════════════════════════╗ +║ SQLite to PostgreSQL 数据迁移工具 ║ +║ ║ +║ 此工具将SQLite数据迁移到PostgreSQL ║ +║ 请确保: ║ +║ 1. PostgreSQL数据库已创建 ║ +║ 2. .env中DATABASE_URL已配置为PostgreSQL ║ +║ 3. SQLite数据文件存在 ║ +╚══════════════════════════════════════════════════════════════╝ + """ + print(banner) + logger.info(banner) + + # 配置 + sqlite_dir = Path(__file__).parent.parent / "data" + target_user_id = "migrated_user" # 默认用户ID + + config_info = f""" +配置信息: + SQLite目录: {sqlite_dir} + PostgreSQL: {settings.database_url} + 目标用户ID: {target_user_id} + 日志文件: {log_filename} +""" + print(config_info) + logger.info(config_info) + + # 确认 + response = input("是否继续迁移? (yes/no): ") + if response.lower() not in ['yes', 'y']: + print("已取消迁移") + return + + # 执行迁移 + migrator = SQLiteToPostgresMigrator(sqlite_dir, target_user_id) + + try: + await migrator.migrate_all() + success_msg = """ +🎉 数据迁移成功完成! + +下一步: + 1. 测试应用功能 + 2. 验证数据完整性 + 3. 备份SQLite文件后可删除 + +详细日志已保存到: {} + """.format(log_filename) + print(success_msg) + logger.info(success_msg) + + except Exception as e: + error_msg = f"\n❌ 迁移失败: {e}\n详细日志已保存到: {log_filename}" + print(error_msg) + logger.error("迁移过程出错", exc_info=True) + + finally: + await migrator.cleanup() + logger.info(f"🔒 数据库连接已关闭,日志文件: {log_filename}") + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/backend/scripts/setup_postgres.py b/backend/scripts/setup_postgres.py new file mode 100644 index 0000000..9ae6a32 --- /dev/null +++ b/backend/scripts/setup_postgres.py @@ -0,0 +1,408 @@ +#!/usr/bin/env python3 +""" +PostgreSQL 数据库自动设置脚本 + +功能: +1. 自动连接到PostgreSQL服务器 +2. 创建数据库和用户 +3. 设置权限 +4. 初始化表结构 + +使用方法: + python backend/scripts/setup_postgres.py + +前置条件: + - PostgreSQL服务已安装并运行 + - 知道PostgreSQL的超级用户密码(通常是postgres用户) +""" +import sys +import asyncio +from pathlib import Path +from getpass import getpass +import logging + +# 添加项目根目录到Python路径 +sys.path.insert(0, str(Path(__file__).parent.parent)) + +try: + import psycopg2 + from psycopg2 import sql + from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT +except ImportError: + print("❌ 缺少psycopg2依赖,正在安装...") + import subprocess + subprocess.check_call([sys.executable, "-m", "pip", "install", "psycopg2-binary"]) + import psycopg2 + from psycopg2 import sql + from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT + +from app.database import init_db + +# 设置日志 +logging.basicConfig( + level=logging.INFO, + format='%(message)s' +) +logger = logging.getLogger(__name__) + + +class PostgreSQLSetup: + """PostgreSQL数据库自动设置""" + + def __init__( + self, + host: str = "localhost", + port: int = 5432, + admin_user: str = "postgres", + admin_password: str = None, + db_name: str = "mumuai_novel", + db_user: str = "mumuai", + db_password: str = "123456" + ): + """ + 初始化设置参数 + + Args: + host: PostgreSQL主机地址 + port: PostgreSQL端口 + admin_user: 管理员用户名 + admin_password: 管理员密码 + db_name: 要创建的数据库名 + db_user: 要创建的用户名 + db_password: 用户密码 + """ + self.host = host + self.port = port + self.admin_user = admin_user + self.admin_password = admin_password + self.db_name = db_name + self.db_user = db_user + self.db_password = db_password + self.conn = None + + def connect_as_admin(self) -> bool: + """连接到PostgreSQL(使用管理员权限)""" + try: + logger.info(f"🔌 连接到 PostgreSQL ({self.host}:{self.port})...") + + self.conn = psycopg2.connect( + host=self.host, + port=self.port, + user=self.admin_user, + password=self.admin_password, + database="postgres" # 连接到默认数据库 + ) + self.conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) + + logger.info(f"✅ 已连接到 PostgreSQL") + return True + + except psycopg2.OperationalError as e: + logger.error(f"❌ 连接失败: {e}") + logger.error("\n可能的原因:") + logger.error("1. PostgreSQL服务未启动") + logger.error("2. 管理员密码错误") + logger.error("3. 主机地址或端口错误") + logger.error("4. pg_hba.conf配置不允许连接") + return False + + def database_exists(self) -> bool: + """检查数据库是否存在""" + cursor = self.conn.cursor() + cursor.execute( + "SELECT 1 FROM pg_database WHERE datname = %s", + (self.db_name,) + ) + exists = cursor.fetchone() is not None + cursor.close() + return exists + + def user_exists(self) -> bool: + """检查用户是否存在""" + cursor = self.conn.cursor() + cursor.execute( + "SELECT 1 FROM pg_user WHERE usename = %s", + (self.db_user,) + ) + exists = cursor.fetchone() is not None + cursor.close() + return exists + + def create_user(self) -> bool: + """创建数据库用户""" + try: + if self.user_exists(): + logger.info(f"ℹ️ 用户 '{self.db_user}' 已存在") + + # 询问是否重置密码 + response = input(f"是否重置用户 '{self.db_user}' 的密码? (yes/no): ") + if response.lower() in ['yes', 'y']: + cursor = self.conn.cursor() + cursor.execute( + sql.SQL("ALTER USER {} WITH PASSWORD %s").format( + sql.Identifier(self.db_user) + ), + (self.db_password,) + ) + cursor.close() + logger.info(f"✅ 用户密码已更新") + + return True + + logger.info(f"👤 创建用户 '{self.db_user}'...") + cursor = self.conn.cursor() + cursor.execute( + sql.SQL("CREATE USER {} WITH PASSWORD %s").format( + sql.Identifier(self.db_user) + ), + (self.db_password,) + ) + cursor.close() + logger.info(f"✅ 用户创建成功") + return True + + except Exception as e: + logger.error(f"❌ 创建用户失败: {e}") + return False + + def create_database(self) -> bool: + """创建数据库""" + try: + if self.database_exists(): + logger.info(f"ℹ️ 数据库 '{self.db_name}' 已存在") + + # 询问是否删除重建 + response = input(f"是否删除并重建数据库 '{self.db_name}'? (yes/no): ") + if response.lower() in ['yes', 'y']: + logger.warning(f"⚠️ 删除数据库 '{self.db_name}'...") + cursor = self.conn.cursor() + # 断开所有连接 + cursor.execute( + sql.SQL(""" + SELECT pg_terminate_backend(pg_stat_activity.pid) + FROM pg_stat_activity + WHERE pg_stat_activity.datname = %s + AND pid <> pg_backend_pid() + """), + (self.db_name,) + ) + cursor.execute( + sql.SQL("DROP DATABASE {}").format( + sql.Identifier(self.db_name) + ) + ) + cursor.close() + logger.info(f"✅ 数据库已删除") + else: + return True + + logger.info(f"🗄️ 创建数据库 '{self.db_name}'...") + cursor = self.conn.cursor() + cursor.execute( + sql.SQL("CREATE DATABASE {} OWNER {}").format( + sql.Identifier(self.db_name), + sql.Identifier(self.db_user) + ) + ) + cursor.close() + logger.info(f"✅ 数据库创建成功") + return True + + except Exception as e: + logger.error(f"❌ 创建数据库失败: {e}") + return False + + def grant_privileges(self) -> bool: + """授予用户权限""" + try: + logger.info(f"🔐 授予用户权限...") + cursor = self.conn.cursor() + + # 授予数据库所有权限 + cursor.execute( + sql.SQL("GRANT ALL PRIVILEGES ON DATABASE {} TO {}").format( + sql.Identifier(self.db_name), + sql.Identifier(self.db_user) + ) + ) + + cursor.close() + logger.info(f"✅ 权限授予成功") + return True + + except Exception as e: + logger.error(f"❌ 授予权限失败: {e}") + return False + + def update_env_file(self) -> bool: + """更新.env文件""" + try: + env_file = Path(__file__).parent.parent / ".env" + + database_url = ( + f"postgresql+asyncpg://{self.db_user}:{self.db_password}" + f"@{self.host}:{self.port}/{self.db_name}" + ) + + if env_file.exists(): + # 读取现有内容 + with open(env_file, 'r', encoding='utf-8') as f: + lines = f.readlines() + + # 更新DATABASE_URL + updated = False + for i, line in enumerate(lines): + if line.startswith('DATABASE_URL='): + lines[i] = f"DATABASE_URL={database_url}\n" + updated = True + break + + if not updated: + lines.append(f"\nDATABASE_URL={database_url}\n") + + # 写回文件 + with open(env_file, 'w', encoding='utf-8') as f: + f.writelines(lines) + else: + # 创建新文件 + with open(env_file, 'w', encoding='utf-8') as f: + f.write(f"DATABASE_URL={database_url}\n") + + logger.info(f"✅ .env 文件已更新") + logger.info(f" DATABASE_URL={database_url}") + return True + + except Exception as e: + logger.error(f"❌ 更新.env文件失败: {e}") + return False + + async def initialize_tables(self) -> bool: + """初始化数据库表结构""" + try: + logger.info(f"📋 初始化数据库表结构...") + await init_db('system') + logger.info(f"✅ 表结构初始化成功") + return True + except Exception as e: + logger.error(f"❌ 初始化表结构失败: {e}") + return False + + def close(self): + """关闭数据库连接""" + if self.conn: + self.conn.close() + logger.info(f"🔌 已断开连接") + + async def setup(self) -> bool: + """执行完整设置流程""" + try: + # 1. 连接 + if not self.connect_as_admin(): + return False + + # 2. 创建用户 + if not self.create_user(): + return False + + # 3. 创建数据库 + if not self.create_database(): + return False + + # 4. 授予权限 + if not self.grant_privileges(): + return False + + # 5. 更新配置 + if not self.update_env_file(): + return False + + # 6. 关闭管理员连接 + self.close() + + # 7. 初始化表结构 + if not await self.initialize_tables(): + return False + + return True + + except Exception as e: + logger.error(f"❌ 设置过程出错: {e}") + return False + finally: + if self.conn: + self.close() + + +async def main(): + """主函数""" + print(""" +╔═══════════════════════════════════════════════════════════════╗ +║ PostgreSQL 数据库自动设置工具 ║ +║ ║ +║ 此工具将自动完成: ║ +║ 1. 连接到PostgreSQL服务器 ║ +║ 2. 创建数据库和用户 ║ +║ 3. 设置权限 ║ +║ 4. 初始化表结构 ║ +║ 5. 更新.env配置文件 ║ +╚═══════════════════════════════════════════════════════════════╝ + """) + + # 获取配置 + print("请输入PostgreSQL配置信息:\n") + + host = input("主机地址 [localhost]: ").strip() or "localhost" + port = input("端口 [5432]: ").strip() or "5432" + port = int(port) + + admin_user = input("管理员用户名 [postgres]: ").strip() or "postgres" + admin_password = getpass(f"管理员密码: ") + + print("\n请输入要创建的数据库信息:\n") + db_name = input("数据库名 [mumuai_novel]: ").strip() or "mumuai_novel" + db_user = input("数据库用户名 [mumuai]: ").strip() or "mumuai" + db_password = getpass("数据库用户密码 [mumuai123]: ") or "mumuai123" + + print(f"\n{'='*60}") + print(f"配置摘要:") + print(f" 服务器: {host}:{port}") + print(f" 数据库: {db_name}") + print(f" 用户: {db_user}") + print(f"{'='*60}\n") + + response = input("确认开始设置? (yes/no): ") + if response.lower() not in ['yes', 'y']: + print("已取消设置") + return + + # 执行设置 + setup = PostgreSQLSetup( + host=host, + port=port, + admin_user=admin_user, + admin_password=admin_password, + db_name=db_name, + db_user=db_user, + db_password=db_password + ) + + print(f"\n{'='*60}") + success = await setup.setup() + print(f"{'='*60}\n") + + if success: + print("🎉 PostgreSQL设置完成!\n") + print("下一步:") + print("1. 启动应用: python -m app.main") + print("2. 访问: http://localhost:8000") + print("3. 查看API文档: http://localhost:8000/docs") + else: + print("❌ 设置过程中出现错误,请检查日志") + print("\n故障排查:") + print("1. 确认PostgreSQL服务正在运行") + print("2. 检查管理员用户名和密码") + print("3. 查看PostgreSQL日志") + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index a312b71..ce68e2b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,15 +1,67 @@ services: + postgres: + image: postgres:18-alpine + container_name: mumuainovel-postgres + environment: + POSTGRES_DB: mumuai_novel + POSTGRES_USER: mumuai + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-mumuai_password_2024} + POSTGRES_INITDB_ARGS: "--encoding=UTF8 --locale=C" + TZ: Asia/Shanghai + volumes: + - postgres_data:/var/lib/postgresql/data + - ./backend/scripts/init_postgres.sql:/docker-entrypoint-initdb.d/init.sql:ro + ports: + - "5432:5432" + restart: unless-stopped + healthcheck: + test: ["CMD-SHELL", "pg_isready -U mumuai -d mumuai_novel"] + interval: 10s + timeout: 5s + retries: 5 + start_period: 10s + networks: + - ai-story-network + command: + - postgres + - -c + - max_connections=200 + - -c + - shared_buffers=256MB + - -c + - effective_cache_size=1GB + - -c + - maintenance_work_mem=64MB + - -c + - checkpoint_completion_target=0.9 + - -c + - wal_buffers=16MB + - -c + - default_statistics_target=100 + - -c + - random_page_cost=1.1 + - -c + - effective_io_concurrency=200 + - -c + - work_mem=4MB + - -c + - min_wal_size=1GB + - -c + - max_wal_size=4GB + mumuainovel: build: context: . dockerfile: Dockerfile image: mumujie/mumuainovel:latest container_name: mumuainovel + depends_on: + postgres: + condition: service_healthy ports: - "8000:8000" volumes: - # 持久化数据库和日志 - - ./data:/app/data + # 持久化日志 - ./logs:/app/logs # 挂载环境变量文件(可选) - ./.env:/app/.env:ro @@ -21,8 +73,20 @@ services: - APP_PORT=8000 - DEBUG=false - # 重要:环境变量会从 .env 文件自动加载 - # 也可以在这里显式设置,优先级:此处设置 > .env 文件 + # 数据库配置(使用PostgreSQL) + - DATABASE_URL=postgresql+asyncpg://mumuai:${POSTGRES_PASSWORD:-mumuai_password_2024}@postgres:5432/mumuai_novel + + # PostgreSQL连接池配置 + - DATABASE_POOL_SIZE=30 + - DATABASE_MAX_OVERFLOW=20 + - DATABASE_POOL_TIMEOUT=60 + - DATABASE_POOL_RECYCLE=1800 + - DATABASE_POOL_PRE_PING=True + - DATABASE_POOL_USE_LIFO=True + + - HTTP_PROXY=http://172.16.66.175:7890 + - HTTPS_PROXY=http://172.16.66.175:7890 + - NO_PROXY=localhost,127.0.0.1 # AI服务配置(建议在 .env 文件中设置) # - OPENAI_API_KEY=${OPENAI_API_KEY} @@ -41,10 +105,15 @@ services: interval: 30s timeout: 10s retries: 3 - start_period: 10s + start_period: 30s networks: - ai-story-network +volumes: + postgres_data: + driver: local + networks: ai-story-network: driver: bridge + diff --git a/frontend/src/components/ChapterAnalysis.tsx b/frontend/src/components/ChapterAnalysis.tsx index 64c59ed..559f31d 100644 --- a/frontend/src/components/ChapterAnalysis.tsx +++ b/frontend/src/components/ChapterAnalysis.tsx @@ -67,6 +67,14 @@ export default function ChapterAnalysis({ chapterId, visible, onClose }: Chapter } const taskData: AnalysisTask = await response.json(); + + // 如果状态为 none(无任务),设置 task 为 null,让前端显示"开始分析"按钮 + if (taskData.status === 'none' || !taskData.has_task) { + setTask(null); + setError(null); // 清除错误,这不是错误状态 + return; + } + setTask(taskData); if (taskData.status === 'completed') { diff --git a/frontend/src/pages/Chapters.tsx b/frontend/src/pages/Chapters.tsx index e427630..f4d8af0 100644 --- a/frontend/src/pages/Chapters.tsx +++ b/frontend/src/pages/Chapters.tsx @@ -321,6 +321,7 @@ export default function Chapters() { setAnalysisTasksMap(prev => ({ ...prev, [editingId]: { + has_task: true, task_id: taskId, chapter_id: editingId, status: 'pending', diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 3b25137..4b654ff 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -390,14 +390,16 @@ export interface ApiError { // 章节分析任务相关类型 export interface AnalysisTask { - task_id: string; + has_task: boolean; + task_id: string | null; chapter_id: string; - status: 'pending' | 'running' | 'completed' | 'failed'; + status: 'pending' | 'running' | 'completed' | 'failed' | 'none'; progress: number; - error_message?: string; - created_at?: string; - started_at?: string; - completed_at?: string; + error_message?: string | null; + auto_recovered?: boolean; + created_at?: string | null; + started_at?: string | null; + completed_at?: string | null; } // 分析结果 - 钩子