update:1.切换数据库PostgreSQL

This commit is contained in:
xiamuceer
2025-11-10 21:16:55 +08:00
parent dfea51cfa4
commit 20d9319a16
31 changed files with 2526 additions and 256 deletions
+1
View File
@@ -39,6 +39,7 @@ Thumbs.db
# 数据库文件(不包含在镜像中) # 数据库文件(不包含在镜像中)
data/*.db data/*.db
backend/data/*.db backend/data/*.db
postgres_data/
# ChromaDB数据库(不包含在镜像中,会在运行时生成) # ChromaDB数据库(不包含在镜像中,会在运行时生成)
backend/data/chroma_db/ backend/data/chroma_db/
+100 -24
View File
@@ -26,7 +26,7 @@
- 🌐 **世界观设定** - 构建完整的故事世界观和背景设定 - 🌐 **世界观设定** - 构建完整的故事世界观和背景设定
- 🔐 **多种登录方式** - 支持 LinuxDO OAuth 登录和本地账户登录 - 🔐 **多种登录方式** - 支持 LinuxDO OAuth 登录和本地账户登录
- 🐳 **Docker 部署** - 一键部署,开箱即用 - 🐳 **Docker 部署** - 一键部署,开箱即用
- 💾 **数据持久化** - 基于 SQLite 的本地数据存储,支持多用户隔离 - 💾 **数据持久化** - 支持 PostgreSQL 和 SQLite 双数据库,多用户数据隔离
- 🎨 **现代化 UI** - 基于 Ant Design 的美观界面,响应式设计 - 🎨 **现代化 UI** - 基于 Ant Design 的美观界面,响应式设计
@@ -36,7 +36,7 @@
- [ ] **灵感模式** - 提供创作灵感和点子生成功能 - [ ] **灵感模式** - 提供创作灵感和点子生成功能
- [✔] **自定义写作风格** - 支持自定义AI写作风格和语言风格 - [✔] **自定义写作风格** - 支持自定义AI写作风格和语言风格
- [ ] **支持数据导入导出** - 支持项目数据的导入和导出功能 - [✔] **支持数据导入导出** - 支持项目数据的导入和导出功能
- [ ] **添加prompt调整界面** - 提供可视化的prompt模板编辑和调整界面 - [ ] **添加prompt调整界面** - 提供可视化的prompt模板编辑和调整界面
- [✔] **开放章节内容字数限制** - 支持用户在生成章节内容时设置字数 @wyf007 - [✔] **开放章节内容字数限制** - 支持用户在生成章节内容时设置字数 @wyf007
- [ ] **设定追溯与矛盾检测** - 对大纲、世界观、角色档案中的设定支持悬停查看注释,显示相关章节来源和佐证原文;自动检测新章节与已有设定的矛盾(吃书),标记为"矛盾"设定并提供解决建议,当新设定解决矛盾后自动更新注释说明 @lulujiang - [ ] **设定追溯与矛盾检测** - 对大纲、世界观、角色档案中的设定支持悬停查看注释,显示相关章节来源和佐证原文;自动检测新章节与已有设定的矛盾(吃书),标记为"矛盾"设定并提供解决建议,当新设定解决矛盾后自动更新注释说明 @lulujiang
@@ -116,9 +116,77 @@ npm run build
## 🐳 部署方式 ## 🐳 部署方式
### Docker Compose 部署 ### Docker Compose 部署PostgreSQL
#### 使用 Docker Hub 镜像(推荐) **推荐生产环境使用**,提供更好的性能和并发支持。
#### 快速启动
```bash
# 1. 克隆项目
git clone https://github.com/xiamuceer-j/MuMuAINovel.git
cd MuMuAINovel
# 2. 配置环境变量
cp backend/.env.example .env
# 编辑 .env 文件,设置必要的配置:
# - POSTGRES_PASSWORD(数据库密码)
# - OPENAI_API_KEYAI服务密钥)
# - 其他可选配置
# 3. 启动服务(包含PostgreSQL
docker-compose up -d
# 4. 查看服务状态
docker-compose ps
# 5. 查看日志
docker-compose logs -f
# 6. 访问应用
# 打开浏览器访问 http://localhost:8000
```
#### 环境变量配置
创建 `.env` 文件并配置:
```bash
# PostgreSQL数据库密码(必须设置)
POSTGRES_PASSWORD=your_secure_password_here
# AI服务配置(必须设置)
OPENAI_API_KEY=your_openai_api_key
DEFAULT_AI_PROVIDER=openai
DEFAULT_MODEL=gpt-4
# 本地账户登录(可选)
LOCAL_AUTH_ENABLED=true
LOCAL_AUTH_USERNAME=admin
LOCAL_AUTH_PASSWORD=admin123
# 其他配置见 backend/.env.example
```
#### 服务说明
- **postgres**: PostgreSQL 18 数据库
- 端口:5432
- 数据持久化:`./postgres_data`
- 已优化配置,支持80-150并发用户
- **mumuainovel**: 主应用服务
- 端口:8000
- 自动等待数据库就绪
- 日志持久化:`./logs`
详细部署指南请参考:[Docker + PostgreSQL 部署文档](docs/docker-postgres-deployment.md)
### Docker Compose 部署(SQLite
适合个人使用或小团队,配置更简单。
#### 使用 Docker Hub 镜像
项目已发布到 Docker Hub,可直接拉取使用: 项目已发布到 Docker Hub,可直接拉取使用:
@@ -197,7 +265,11 @@ networks:
#### 2. 数据持久化 #### 2. 数据持久化
数据目录已通过 volume 挂载,数据不会丢失 **PostgreSQL部署**
- `./postgres_data`PostgreSQL 数据库文件
- `./logs`:应用日志文件
**SQLite部署**
- `./data`SQLite 数据库文件 - `./data`SQLite 数据库文件
- `./logs`:应用日志文件 - `./logs`:应用日志文件
@@ -238,11 +310,33 @@ services:
## ⚙️ 配置说明 ## ⚙️ 配置说明
### 数据库选择
项目支持两种数据库:
```bash
# .env 配置
DATABASE_URL=postgresql+asyncpg://mumuai:password@postgres:5432/mumuai_novel
```
```bash
# .env 配置
DATABASE_URL=sqlite+aiosqlite:///data/ai_story.db
```
### 环境变量 ### 环境变量
创建 `.env` 文件并配置以下变量: 创建 `.env` 文件并配置以下变量:
```bash ```bash
# ===== 数据库配置 =====
# PostgreSQL(生产环境推荐)
DATABASE_URL=postgresql+asyncpg://mumuai:password@postgres:5432/mumuai_novel
POSTGRES_PASSWORD=your_secure_password_here
# 或使用 SQLite(开发环境)
# DATABASE_URL=sqlite+aiosqlite:///data/ai_story.db
# ===== AI 服务配置(必填)===== # ===== AI 服务配置(必填)=====
# OpenAI 配置(支持官方API和中转API) # OpenAI 配置(支持官方API和中转API)
OPENAI_API_KEY=your_openai_key_here OPENAI_API_KEY=your_openai_key_here
@@ -325,24 +419,6 @@ OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxx
OPENAI_BASE_URL=https://api.new-api.com/v1 OPENAI_BASE_URL=https://api.new-api.com/v1
``` ```
**API2D**
```bash
OPENAI_API_KEY=fk-xxxxxxxxxxxxxxxx
OPENAI_BASE_URL=https://api.api2d.com/v1
```
**OpenAI-SB**
```bash
OPENAI_API_KEY=sb-xxxxxxxxxxxxxxxx
OPENAI_BASE_URL=https://api.openai-sb.com/v1
```
**自建 One API / New API**
```bash
OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxx
OPENAI_BASE_URL=https://your-domain.com/v1
```
##### 注意事项 ##### 注意事项
- ✅ 所有支持 OpenAI 接口格式的服务都可以使用 - ✅ 所有支持 OpenAI 接口格式的服务都可以使用
@@ -436,7 +512,7 @@ MuMuAINovel/
### 后端 ### 后端
- **框架**FastAPI 0.109.0 - **框架**FastAPI 0.109.0
- **数据库**SQLite + SQLAlchemy(异步) - **数据库**PostgreSQL / SQLite + SQLAlchemy(异步)
- **AI 集成**OpenAI、Anthropic、Google Gemini SDK - **AI 集成**OpenAI、Anthropic、Google Gemini SDK
- **认证**LinuxDO OAuth2、本地账户 - **认证**LinuxDO OAuth2、本地账户
- **日志**Python logging + 文件轮转 - **日志**Python logging + 文件轮转
+124 -38
View File
@@ -1,54 +1,140 @@
# AI服务配置 # ==========================================
# OpenAI配置 # MuMuAINovel 配置文件示例
OPENAI_API_KEY=your_openai_key_here # ==========================================
OPENAI_BASE_URL=https://api.openai.com/v1 # 复制此文件为 .env 并修改配置值
# cp .env.example .env
# Anthropic配置
ANTHROPIC_API_KEY=your_anthropic_key_here
ANTHROPIC_BASE_URL=https://api.anthropic.com
# 默认AI提供商:openai, gemini, anthropic
DEFAULT_AI_PROVIDER=openai
DEFAULT_MODEL=gpt-4.1
DEFAULT_TEMPERATURE=0.8
DEFAULT_MAX_TOKENS=32000
# ==========================================
# 应用配置 # 应用配置
# ==========================================
APP_NAME=MuMuAINovel APP_NAME=MuMuAINovel
APP_VERSION=1.0.0 APP_VERSION=1.0.0
APP_HOST=0.0.0.0 APP_HOST=0.0.0.0
APP_PORT=8000 APP_PORT=8000
DEBUG=true DEBUG=True
# ==========================================
# 数据库配置
# ==========================================
# 选项1: PostgreSQL(生产环境推荐)
# DATABASE_URL=postgresql+asyncpg://username:password@localhost:5432/database_name
# 示例:
# DATABASE_URL=postgresql+asyncpg://mumuai:your_password@localhost:5432/mumuai_novel
# 选项2: SQLite(开发环境,默认)
DATABASE_URL=sqlite+aiosqlite:///data/ai_story.db
# PostgreSQL 连接池配置(优化后,支持80-150并发用户)
DATABASE_POOL_SIZE=30 # 核心连接数(默认30,小团队可用20)
DATABASE_MAX_OVERFLOW=20 # 最大溢出连接数(默认20,小团队可用10)
DATABASE_POOL_TIMEOUT=60 # 连接等待超时秒数(默认60
DATABASE_POOL_RECYCLE=1800 # 连接回收时间秒数(默认1800=30分钟)
DATABASE_POOL_PRE_PING=True # 连接前检测是否有效
DATABASE_POOL_USE_LIFO=True # 使用LIFO策略提高连接复用率
# 会话监控配置
DATABASE_SESSION_MAX_ACTIVE=50 # 活跃会话警告阈值
DATABASE_SESSION_LEAK_THRESHOLD=100 # 会话泄漏严重告警阈值
# SQLite 性能优化配置(仅在使用SQLite时生效)
SQLITE_CACHE_SIZE_MB=128 # SQLite缓存大小(MB),默认128
SQLITE_MMAP_SIZE_MB=256 # 内存映射I/O大小(MB),默认256
SQLITE_WAL_AUTOCHECKPOINT=1000 # WAL自动检查点间隔
# 数据库监控配置
DATABASE_ENABLE_SLOW_QUERY_LOG=True # 启用慢查询日志
DATABASE_SLOW_QUERY_THRESHOLD=1.0 # 慢查询阈值(秒)
DATABASE_ENABLE_METRICS=True # 启用性能指标收集
# ==========================================
# 日志配置
# ==========================================
LOG_LEVEL=INFO
LOG_TO_FILE=True
LOG_FILE_PATH=logs/app.log
LOG_MAX_BYTES=10485760
LOG_BACKUP_COUNT=30
# ==========================================
# CORS配置
# ==========================================
CORS_ORIGINS=["http://localhost:8000","http://127.0.0.1:8000"]
# ==========================================
# AI服务配置
# ==========================================
# OpenAI配置
OPENAI_API_KEY=your_openai_api_key_here
OPENAI_BASE_URL=https://api.openai.com/v1
# 或使用兼容的API服务(如DeepSeek
# OPENAI_BASE_URL=https://api.deepseek.com/v1
# Gemini配置(可选)
# GEMINI_API_KEY=your_gemini_api_key_here
# GEMINI_BASE_URL=https://generativelanguage.googleapis.com
# Anthropic配置(可选)
# ANTHROPIC_API_KEY=your_anthropic_api_key_here
# ANTHROPIC_BASE_URL=https://api.anthropic.com
# 默认AI配置
DEFAULT_AI_PROVIDER=openai
DEFAULT_MODEL=gpt-4
DEFAULT_TEMPERATURE=0.7
DEFAULT_MAX_TOKENS=2000
# ==========================================
# LinuxDO OAuth2 配置(可选) # LinuxDO OAuth2 配置(可选)
# 注意:Docker部署时,LINUXDO_REDIRECT_URI 应该使用实际的域名或服务器IP # ==========================================
# 本地开发: http://localhost:8000/api/auth/callback # LINUXDO_CLIENT_ID=your_client_id
# 生产环境: https://your-domain.com/api/auth/callback 或 http://your-server-ip:8000/api/auth/callback # LINUXDO_CLIENT_SECRET=your_client_secret
LINUXDO_CLIENT_ID=your_client_id_here # LINUXDO_REDIRECT_URI=http://localhost:8000/api/auth/callback
LINUXDO_CLIENT_SECRET=your_client_secret_here
LINUXDO_REDIRECT_URI=http://localhost:8000/api/auth/callback
# 前端URL配置(用于OAuth回调后重定向到前端 # 前端URLOAuth回调后重定向)
# 本地开发: http://localhost:8000
# 生产环境: https://your-domain.com 或 http://your-server-ip:8000
FRONTEND_URL=http://localhost:8000 FRONTEND_URL=http://localhost:8000
# 本地账户登录配置 # 初始管理员(LinuxDO user_id
# 启用本地账户登录(true/false) # INITIAL_ADMIN_LINUXDO_ID=12345
LOCAL_AUTH_ENABLED=true
# 本地登录用户名
LOCAL_AUTH_USERNAME=admin
# 本地登录密码
LOCAL_AUTH_PASSWORD=your_secure_password_here
# 本地用户显示名称
LOCAL_AUTH_DISPLAY_NAME=管理员
# ==========================================
# 本地账户登录配置
# ==========================================
LOCAL_AUTH_ENABLED=True
LOCAL_AUTH_USERNAME=admin
LOCAL_AUTH_PASSWORD=admin123
LOCAL_AUTH_DISPLAY_NAME=本地管理员
# ==========================================
# 会话配置 # 会话配置
# 会话过期时间(分钟),默认120分钟(2小时) # ==========================================
SESSION_EXPIRE_MINUTES=120 SESSION_EXPIRE_MINUTES=120
# 会话刷新阈值(分钟),剩余时间少于此值时可刷新,默认30分钟
SESSION_REFRESH_THRESHOLD_MINUTES=30 SESSION_REFRESH_THRESHOLD_MINUTES=30
# CORS配置(生产环境) # ==========================================
# 允许的跨域来源,多个用逗号分隔 # 部署配置说明
# CORS_ORIGINS=https://your-domain.com,https://www.your-domain.com # ==========================================
# 生产环境 PostgreSQL 配置示例(50-100并发用户):
# DATABASE_URL=postgresql+asyncpg://mumuai:SecurePassword123@db.example.com:5432/mumuai_prod
# DATABASE_POOL_SIZE=30
# DATABASE_MAX_OVERFLOW=20
# DATABASE_POOL_TIMEOUT=60
# DATABASE_POOL_RECYCLE=1800
# DATABASE_SESSION_MAX_ACTIVE=50
# DEBUG=False
# LOG_LEVEL=WARNING
# 高并发环境 PostgreSQL 配置示例(100+并发用户):
# DATABASE_URL=postgresql+asyncpg://mumuai:SecurePassword123@db.example.com:5432/mumuai_prod
# DATABASE_POOL_SIZE=40
# DATABASE_MAX_OVERFLOW=30
# DATABASE_SESSION_MAX_ACTIVE=80
# DATABASE_SESSION_LEAK_THRESHOLD=150
# Docker 部署配置示例:
# DATABASE_URL=postgresql+asyncpg://mumuai:password@postgres:5432/mumuai_novel
# OPENAI_BASE_URL=https://api.openai.com/v1
# FRONTEND_URL=https://your-domain.com
# LINUXDO_REDIRECT_URI=https://your-domain.com/api/auth/callback
+125 -17
View File
@@ -43,6 +43,39 @@ logger = get_logger(__name__)
db_write_locks: dict[str, Lock] = {} db_write_locks: dict[str, Lock] = {}
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
"""
验证用户是否有权访问指定项目
Args:
project_id: 项目ID
user_id: 用户ID
db: 数据库会话
Returns:
Project: 项目对象
Raises:
HTTPException: 401 未登录,404 项目不存在或无权访问
"""
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
result = await db.execute(
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
return project
async def get_db_write_lock(user_id: str) -> Lock: async def get_db_write_lock(user_id: str) -> Lock:
"""获取或创建用户的数据库写入锁""" """获取或创建用户的数据库写入锁"""
if user_id not in db_write_locks: if user_id not in db_write_locks:
@@ -54,16 +87,13 @@ async def get_db_write_lock(user_id: str) -> Lock:
@router.post("", response_model=ChapterResponse, summary="创建章节") @router.post("", response_model=ChapterResponse, summary="创建章节")
async def create_chapter( async def create_chapter(
chapter: ChapterCreate, chapter: ChapterCreate,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
"""创建新的章节""" """创建新的章节"""
# 验证项目是否存在 # 验证用户权限和项目是否存在
result = await db.execute( user_id = getattr(request.state, 'user_id', None)
select(Project).where(Project.id == chapter.project_id) project = await verify_project_access(chapter.project_id, user_id, db)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 计算字数 # 计算字数
word_count = len(chapter.content) word_count = len(chapter.content)
@@ -85,9 +115,14 @@ async def create_chapter(
@router.get("/project/{project_id}", response_model=ChapterListResponse, summary="获取项目的所有章节") @router.get("/project/{project_id}", response_model=ChapterListResponse, summary="获取项目的所有章节")
async def get_project_chapters( async def get_project_chapters(
project_id: str, project_id: str,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
"""获取指定项目的所有章节(路径参数版本)""" """获取指定项目的所有章节(路径参数版本)"""
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(project_id, user_id, db)
# 获取总数 # 获取总数
count_result = await db.execute( count_result = await db.execute(
select(func.count(Chapter.id)).where(Chapter.project_id == project_id) select(func.count(Chapter.id)).where(Chapter.project_id == project_id)
@@ -108,6 +143,7 @@ async def get_project_chapters(
@router.get("/{chapter_id}", response_model=ChapterResponse, summary="获取章节详情") @router.get("/{chapter_id}", response_model=ChapterResponse, summary="获取章节详情")
async def get_chapter( async def get_chapter(
chapter_id: str, chapter_id: str,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
"""根据ID获取章节详情""" """根据ID获取章节详情"""
@@ -119,12 +155,17 @@ async def get_chapter(
if not chapter: if not chapter:
raise HTTPException(status_code=404, detail="章节不存在") raise HTTPException(status_code=404, detail="章节不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(chapter.project_id, user_id, db)
return chapter return chapter
@router.get("/{chapter_id}/navigation", summary="获取章节导航信息") @router.get("/{chapter_id}/navigation", summary="获取章节导航信息")
async def get_chapter_navigation( async def get_chapter_navigation(
chapter_id: str, chapter_id: str,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
""" """
@@ -140,6 +181,10 @@ async def get_chapter_navigation(
if not current_chapter: if not current_chapter:
raise HTTPException(status_code=404, detail="章节不存在") raise HTTPException(status_code=404, detail="章节不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(current_chapter.project_id, user_id, db)
# 获取上一章 # 获取上一章
prev_result = await db.execute( prev_result = await db.execute(
select(Chapter) select(Chapter)
@@ -183,6 +228,7 @@ async def get_chapter_navigation(
async def update_chapter( async def update_chapter(
chapter_id: str, chapter_id: str,
chapter_update: ChapterUpdate, chapter_update: ChapterUpdate,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
"""更新章节信息""" """更新章节信息"""
@@ -194,6 +240,10 @@ async def update_chapter(
if not chapter: if not chapter:
raise HTTPException(status_code=404, detail="章节不存在") raise HTTPException(status_code=404, detail="章节不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(chapter.project_id, user_id, db)
# 记录旧字数 # 记录旧字数
old_word_count = chapter.word_count or 0 old_word_count = chapter.word_count or 0
@@ -223,6 +273,7 @@ async def update_chapter(
@router.delete("/{chapter_id}", summary="删除章节") @router.delete("/{chapter_id}", summary="删除章节")
async def delete_chapter( async def delete_chapter(
chapter_id: str, chapter_id: str,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
"""删除章节""" """删除章节"""
@@ -234,6 +285,10 @@ async def delete_chapter(
if not chapter: if not chapter:
raise HTTPException(status_code=404, detail="章节不存在") raise HTTPException(status_code=404, detail="章节不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(chapter.project_id, user_id, db)
# 更新项目字数 # 更新项目字数
result = await db.execute( result = await db.execute(
select(Project).where(Project.id == chapter.project_id) select(Project).where(Project.id == chapter.project_id)
@@ -481,6 +536,7 @@ async def build_smart_chapter_context(
@router.get("/{chapter_id}/can-generate", summary="检查章节是否可以生成") @router.get("/{chapter_id}/can-generate", summary="检查章节是否可以生成")
async def check_can_generate( async def check_can_generate(
chapter_id: str, chapter_id: str,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
""" """
@@ -495,6 +551,10 @@ async def check_can_generate(
if not chapter: if not chapter:
raise HTTPException(status_code=404, detail="章节不存在") raise HTTPException(status_code=404, detail="章节不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(chapter.project_id, user_id, db)
# 检查前置条件 # 检查前置条件
can_generate, error_msg, previous_chapters = await check_prerequisites(db, chapter) can_generate, error_msg, previous_chapters = await check_prerequisites(db, chapter)
@@ -1238,6 +1298,7 @@ async def generate_chapter_content_stream(
@router.get("/{chapter_id}/analysis/status", summary="查询章节分析任务状态") @router.get("/{chapter_id}/analysis/status", summary="查询章节分析任务状态")
async def get_analysis_task_status( async def get_analysis_task_status(
chapter_id: str, chapter_id: str,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
""" """
@@ -1248,16 +1309,32 @@ async def get_analysis_task_status(
- 如果任务状态为pending且超过2分钟未启动,自动标记为failed - 如果任务状态为pending且超过2分钟未启动,自动标记为failed
返回: 返回:
- task_id: 任务ID - has_task: 是否存在分析任务
- status: pending/running/completed/failed - task_id: 任务ID(如果存在)
- status: pending/running/completed/failed/none(如果不存在则为none
- progress: 0-100 - progress: 0-100
- error_message: 错误信息(如果失败) - error_message: 错误信息(如果失败)
- auto_recovered: 是否被自动恢复 - auto_recovered: 是否被自动恢复
- created_at: 创建时间 - created_at: 创建时间
- completed_at: 完成时间 - completed_at: 完成时间
注意:当章节不存在或无权访问时返回404,当没有分析任务时返回has_task=false
""" """
from datetime import timedelta from datetime import timedelta
# 先获取章节以验证存在性和权限
chapter_result = await db.execute(
select(Chapter).where(Chapter.id == chapter_id)
)
chapter = chapter_result.scalar_one_or_none()
if not chapter:
raise HTTPException(status_code=404, detail="章节不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(chapter.project_id, user_id, db)
# 获取该章节最新的分析任务 # 获取该章节最新的分析任务
result = await db.execute( result = await db.execute(
select(AnalysisTask) select(AnalysisTask)
@@ -1268,7 +1345,19 @@ async def get_analysis_task_status(
task = result.scalar_one_or_none() task = result.scalar_one_or_none()
if not task: if not task:
raise HTTPException(status_code=404, detail="未找到分析任务") # 返回无任务状态,而不是抛出404错误
return {
"has_task": False,
"chapter_id": chapter_id,
"status": "none",
"progress": 0,
"error_message": None,
"auto_recovered": False,
"task_id": None,
"created_at": None,
"started_at": None,
"completed_at": None
}
auto_recovered = False auto_recovered = False
current_time = datetime.now() current_time = datetime.now()
@@ -1299,6 +1388,7 @@ async def get_analysis_task_status(
logger.warning(f"🔄 自动恢复未启动的任务: {task.id}, 章节: {chapter_id}") logger.warning(f"🔄 自动恢复未启动的任务: {task.id}, 章节: {chapter_id}")
return { return {
"has_task": True,
"task_id": task.id, "task_id": task.id,
"chapter_id": task.chapter_id, "chapter_id": task.chapter_id,
"status": task.status, "status": task.status,
@@ -1314,6 +1404,7 @@ async def get_analysis_task_status(
@router.get("/{chapter_id}/analysis", summary="获取章节分析结果") @router.get("/{chapter_id}/analysis", summary="获取章节分析结果")
async def get_chapter_analysis( async def get_chapter_analysis(
chapter_id: str, chapter_id: str,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
""" """
@@ -1325,6 +1416,16 @@ async def get_chapter_analysis(
- memories: 提取的记忆列表 - memories: 提取的记忆列表
- created_at: 分析时间 - created_at: 分析时间
""" """
# 先获取章节以验证权限
chapter_result_check = await db.execute(
select(Chapter).where(Chapter.id == chapter_id)
)
chapter_check = chapter_result_check.scalar_one_or_none()
if chapter_check:
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(chapter_check.project_id, user_id, db)
# 获取分析结果 # 获取分析结果
analysis_result = await db.execute( analysis_result = await db.execute(
select(PlotAnalysis) select(PlotAnalysis)
@@ -1369,6 +1470,7 @@ async def get_chapter_analysis(
@router.get("/{chapter_id}/annotations", summary="获取章节标注数据") @router.get("/{chapter_id}/annotations", summary="获取章节标注数据")
async def get_chapter_annotations( async def get_chapter_annotations(
chapter_id: str, chapter_id: str,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
""" """
@@ -1377,6 +1479,9 @@ async def get_chapter_annotations(
返回格式化的标注列表,包含精确位置信息 返回格式化的标注列表,包含精确位置信息
适用于章节内容的可视化标注展示 适用于章节内容的可视化标注展示
""" """
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
# 获取章节 # 获取章节
chapter_result = await db.execute( chapter_result = await db.execute(
select(Chapter).where(Chapter.id == chapter_id) select(Chapter).where(Chapter.id == chapter_id)
@@ -1386,6 +1491,9 @@ async def get_chapter_annotations(
if not chapter: if not chapter:
raise HTTPException(status_code=404, detail="章节不存在") raise HTTPException(status_code=404, detail="章节不存在")
# 验证项目访问权限
await verify_project_access(chapter.project_id, user_id, db)
# 获取分析结果 # 获取分析结果
analysis_result = await db.execute( analysis_result = await db.execute(
select(PlotAnalysis) select(PlotAnalysis)
@@ -1623,13 +1731,8 @@ async def batch_generate_chapters_in_order(
if not user_id: if not user_id:
raise HTTPException(status_code=401, detail="未登录") raise HTTPException(status_code=401, detail="未登录")
# 验证项目存在 # 验证项目存在和用户权限
project_result = await db.execute( project = await verify_project_access(project_id, user_id, db)
select(Project).where(Project.id == project_id)
)
project = project_result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 获取项目的所有章节,按序号排序 # 获取项目的所有章节,按序号排序
result = await db.execute( result = await db.execute(
@@ -1750,12 +1853,17 @@ async def get_batch_generation_status(
@router.get("/project/{project_id}/batch-generate/active", summary="获取项目当前运行中的批量生成任务") @router.get("/project/{project_id}/batch-generate/active", summary="获取项目当前运行中的批量生成任务")
async def get_active_batch_generation( async def get_active_batch_generation(
project_id: str, project_id: str,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
""" """
获取项目当前运行中的批量生成任务 获取项目当前运行中的批量生成任务
用于页面刷新后恢复任务状态 用于页面刷新后恢复任务状态
""" """
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(project_id, user_id, db)
result = await db.execute( result = await db.execute(
select(BatchGenerationTask) select(BatchGenerationTask)
.where(BatchGenerationTask.project_id == project_id) .where(BatchGenerationTask.project_id == project_id)
+61 -10
View File
@@ -24,12 +24,50 @@ router = APIRouter(prefix="/characters", tags=["角色管理"])
logger = get_logger(__name__) logger = get_logger(__name__)
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
"""
验证用户是否有权访问指定项目
Args:
project_id: 项目ID
user_id: 用户ID
db: 数据库会话
Returns:
Project: 项目对象
Raises:
HTTPException: 401 未登录,404 项目不存在或无权访问
"""
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
result = await db.execute(
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
return project
@router.get("", response_model=CharacterListResponse, summary="获取角色列表") @router.get("", response_model=CharacterListResponse, summary="获取角色列表")
async def get_characters( async def get_characters(
project_id: str, project_id: str,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
"""获取指定项目的所有角色(query参数版本)""" """获取指定项目的所有角色(query参数版本)"""
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(project_id, user_id, db)
# 获取总数 # 获取总数
count_result = await db.execute( count_result = await db.execute(
select(func.count(Character.id)).where(Character.project_id == project_id) select(func.count(Character.id)).where(Character.project_id == project_id)
@@ -93,9 +131,14 @@ async def get_characters(
@router.get("/project/{project_id}", response_model=CharacterListResponse, summary="获取项目的所有角色") @router.get("/project/{project_id}", response_model=CharacterListResponse, summary="获取项目的所有角色")
async def get_project_characters( async def get_project_characters(
project_id: str, project_id: str,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
"""获取指定项目的所有角色(路径参数版本)""" """获取指定项目的所有角色(路径参数版本)"""
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(project_id, user_id, db)
# 获取总数 # 获取总数
count_result = await db.execute( count_result = await db.execute(
select(func.count(Character.id)).where(Character.project_id == project_id) select(func.count(Character.id)).where(Character.project_id == project_id)
@@ -159,6 +202,7 @@ async def get_project_characters(
@router.get("/{character_id}", response_model=CharacterResponse, summary="获取角色详情") @router.get("/{character_id}", response_model=CharacterResponse, summary="获取角色详情")
async def get_character( async def get_character(
character_id: str, character_id: str,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
"""根据ID获取角色详情""" """根据ID获取角色详情"""
@@ -170,6 +214,10 @@ async def get_character(
if not character: if not character:
raise HTTPException(status_code=404, detail="角色不存在") raise HTTPException(status_code=404, detail="角色不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(character.project_id, user_id, db)
return character return character
@@ -177,6 +225,7 @@ async def get_character(
async def update_character( async def update_character(
character_id: str, character_id: str,
character_update: CharacterUpdate, character_update: CharacterUpdate,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
"""更新角色信息""" """更新角色信息"""
@@ -188,6 +237,10 @@ async def update_character(
if not character: if not character:
raise HTTPException(status_code=404, detail="角色不存在") raise HTTPException(status_code=404, detail="角色不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(character.project_id, user_id, db)
# 更新字段 # 更新字段
update_data = character_update.model_dump(exclude_unset=True) update_data = character_update.model_dump(exclude_unset=True)
for field, value in update_data.items(): for field, value in update_data.items():
@@ -201,6 +254,7 @@ async def update_character(
@router.delete("/{character_id}", summary="删除角色") @router.delete("/{character_id}", summary="删除角色")
async def delete_character( async def delete_character(
character_id: str, character_id: str,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
"""删除角色""" """删除角色"""
@@ -212,6 +266,10 @@ async def delete_character(
if not character: if not character:
raise HTTPException(status_code=404, detail="角色不存在") raise HTTPException(status_code=404, detail="角色不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(character.project_id, user_id, db)
await db.delete(character) await db.delete(character)
await db.commit() await db.commit()
@@ -233,13 +291,9 @@ async def generate_character(
生成内容包括:姓名、年龄、性别、性格、外貌、背景故事、人际关系等 生成内容包括:姓名、年龄、性别、性格、外貌、背景故事、人际关系等
""" """
# 验证项目是否存在并获取项目信息 # 验证用户权限和项目是否存在
result = await db.execute( user_id = getattr(http_request.state, 'user_id', None)
select(Project).where(Project.id == request.project_id) project = await verify_project_access(request.project_id, user_id, db)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
try: try:
# 获取已存在的角色列表,用于关系网络 # 获取已存在的角色列表,用于关系网络
@@ -295,9 +349,6 @@ async def generate_character(
user_input=user_input user_input=user_input
) )
# 获取user_id用于MCP工具调用
user_id = http_request.state.user_id if hasattr(http_request.state, 'user_id') else 'default_user'
# 调用AI生成角色(支持MCP工具) # 调用AI生成角色(支持MCP工具)
logger.info(f"🎯 开始为项目 {request.project_id} 生成角色(启用MCP") logger.info(f"🎯 开始为项目 {request.project_id} 生成角色(启用MCP")
logger.info(f" - 角色名:{request.name or 'AI生成'}") logger.info(f" - 角色名:{request.name or 'AI生成'}")
+55 -8
View File
@@ -6,6 +6,7 @@ from typing import List, Optional
from app.database import get_db from app.database import get_db
from app.models.memory import StoryMemory, PlotAnalysis from app.models.memory import StoryMemory, PlotAnalysis
from app.models.chapter import Chapter from app.models.chapter import Chapter
from app.models.project import Project
from app.services.memory_service import memory_service from app.services.memory_service import memory_service
from app.services.plot_analyzer import get_plot_analyzer from app.services.plot_analyzer import get_plot_analyzer
from app.services.ai_service import create_user_ai_service from app.services.ai_service import create_user_ai_service
@@ -17,6 +18,26 @@ logger = get_logger(__name__)
router = APIRouter(prefix="/api/memories", tags=["memories"]) router = APIRouter(prefix="/api/memories", tags=["memories"])
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
"""验证用户是否有权访问指定项目"""
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
result = await db.execute(
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
return project
@router.post("/projects/{project_id}/analyze-chapter/{chapter_id}") @router.post("/projects/{project_id}/analyze-chapter/{chapter_id}")
async def analyze_chapter( async def analyze_chapter(
project_id: str, project_id: str,
@@ -30,7 +51,10 @@ async def analyze_chapter(
对指定章节进行剧情分析,提取钩子、伏笔、情节点等,并存入记忆系统 对指定章节进行剧情分析,提取钩子、伏笔、情节点等,并存入记忆系统
""" """
try: try:
user_id = request.state.user_id user_id = getattr(request.state, 'user_id', None)
# 验证用户权限
await verify_project_access(project_id, user_id, db)
# 获取章节内容 # 获取章节内容
result = await db.execute( result = await db.execute(
@@ -192,7 +216,10 @@ async def get_project_memories(
): ):
"""获取项目的记忆列表""" """获取项目的记忆列表"""
try: try:
user_id = request.state.user_id user_id = getattr(request.state, 'user_id', None)
# 验证用户权限
await verify_project_access(project_id, user_id, db)
# 构建查询 # 构建查询
query = select(StoryMemory).where(StoryMemory.project_id == project_id) query = select(StoryMemory).where(StoryMemory.project_id == project_id)
@@ -222,10 +249,16 @@ async def get_project_memories(
async def get_chapter_analysis( async def get_chapter_analysis(
project_id: str, project_id: str,
chapter_id: str, chapter_id: str,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
"""获取章节的剧情分析""" """获取章节的剧情分析"""
try: try:
user_id = getattr(request.state, 'user_id', None)
# 验证用户权限
await verify_project_access(project_id, user_id, db)
result = await db.execute( result = await db.execute(
select(PlotAnalysis).where( select(PlotAnalysis).where(
and_( and_(
@@ -258,11 +291,15 @@ async def search_memories(
query: str, query: str,
memory_types: Optional[List[str]] = None, memory_types: Optional[List[str]] = None,
limit: int = 10, limit: int = 10,
min_importance: float = 0.0 min_importance: float = 0.0,
db: AsyncSession = Depends(get_db)
): ):
"""语义搜索项目记忆""" """语义搜索项目记忆"""
try: try:
user_id = request.state.user_id user_id = getattr(request.state, 'user_id', None)
# 验证用户权限
await verify_project_access(project_id, user_id, db)
memories = await memory_service.search_memories( memories = await memory_service.search_memories(
user_id=user_id, user_id=user_id,
@@ -294,7 +331,10 @@ async def get_unresolved_foreshadows(
): ):
"""获取未完结的伏笔""" """获取未完结的伏笔"""
try: try:
user_id = request.state.user_id user_id = getattr(request.state, 'user_id', None)
# 验证用户权限
await verify_project_access(project_id, user_id, db)
# 从向量库搜索 # 从向量库搜索
foreshadows = await memory_service.find_unresolved_foreshadows( foreshadows = await memory_service.find_unresolved_foreshadows(
@@ -317,11 +357,15 @@ async def get_unresolved_foreshadows(
@router.get("/projects/{project_id}/stats") @router.get("/projects/{project_id}/stats")
async def get_memory_stats( async def get_memory_stats(
project_id: str, project_id: str,
request: Request request: Request,
db: AsyncSession = Depends(get_db)
): ):
"""获取记忆统计信息""" """获取记忆统计信息"""
try: try:
user_id = request.state.user_id user_id = getattr(request.state, 'user_id', None)
# 验证用户权限
await verify_project_access(project_id, user_id, db)
stats = await memory_service.get_memory_stats( stats = await memory_service.get_memory_stats(
user_id=user_id, user_id=user_id,
@@ -347,7 +391,10 @@ async def delete_chapter_memories(
): ):
"""删除章节的所有记忆""" """删除章节的所有记忆"""
try: try:
user_id = request.state.user_id user_id = getattr(request.state, 'user_id', None)
# 验证用户权限
await verify_project_access(project_id, user_id, db)
# 从数据库删除 # 从数据库删除
result = await db.execute( result = await db.execute(
+91 -24
View File
@@ -1,5 +1,5 @@
"""组织管理API""" """组织管理API"""
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_ from sqlalchemy import select, and_
from typing import List, Optional from typing import List, Optional
@@ -31,6 +31,26 @@ router = APIRouter(prefix="/organizations", tags=["组织管理"])
logger = get_logger(__name__) logger = get_logger(__name__)
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
"""验证用户是否有权访问指定项目"""
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
result = await db.execute(
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
return project
class OrganizationGenerateRequest(BaseModel): class OrganizationGenerateRequest(BaseModel):
"""AI生成组织的请求模型""" """AI生成组织的请求模型"""
project_id: str = Field(..., description="项目ID") project_id: str = Field(..., description="项目ID")
@@ -44,8 +64,13 @@ class OrganizationGenerateRequest(BaseModel):
@router.get("/project/{project_id}", response_model=List[OrganizationDetailResponse], summary="获取项目的所有组织") @router.get("/project/{project_id}", response_model=List[OrganizationDetailResponse], summary="获取项目的所有组织")
async def get_project_organizations( async def get_project_organizations(
project_id: str, project_id: str,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(project_id, user_id, db)
""" """
获取项目中的所有组织及其详情 获取项目中的所有组织及其详情
@@ -85,6 +110,7 @@ async def get_project_organizations(
@router.get("/{org_id}", response_model=OrganizationResponse, summary="获取组织详情") @router.get("/{org_id}", response_model=OrganizationResponse, summary="获取组织详情")
async def get_organization( async def get_organization(
org_id: str, org_id: str,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
"""获取组织的详细信息""" """获取组织的详细信息"""
@@ -96,12 +122,17 @@ async def get_organization(
if not org: if not org:
raise HTTPException(status_code=404, detail="组织不存在") raise HTTPException(status_code=404, detail="组织不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(org.project_id, user_id, db)
return org return org
@router.post("/", response_model=OrganizationResponse, summary="创建组织") @router.post("/", response_model=OrganizationResponse, summary="创建组织")
async def create_organization( async def create_organization(
organization: OrganizationCreate, organization: OrganizationCreate,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
""" """
@@ -110,6 +141,10 @@ async def create_organization(
- 需要关联到一个已存在的角色记录(is_organization=True - 需要关联到一个已存在的角色记录(is_organization=True
- 可以设置父组织、势力等级等属性 - 可以设置父组织、势力等级等属性
""" """
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(organization.project_id, user_id, db)
# 验证角色是否存在且是组织 # 验证角色是否存在且是组织
char_result = await db.execute( char_result = await db.execute(
select(Character).where(Character.id == organization.character_id) select(Character).where(Character.id == organization.character_id)
@@ -142,6 +177,7 @@ async def create_organization(
async def update_organization( async def update_organization(
org_id: str, org_id: str,
organization: OrganizationUpdate, organization: OrganizationUpdate,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
"""更新组织的属性""" """更新组织的属性"""
@@ -153,6 +189,10 @@ async def update_organization(
if not db_org: if not db_org:
raise HTTPException(status_code=404, detail="组织不存在") raise HTTPException(status_code=404, detail="组织不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(db_org.project_id, user_id, db)
# 更新字段 # 更新字段
update_data = organization.model_dump(exclude_unset=True) update_data = organization.model_dump(exclude_unset=True)
for field, value in update_data.items(): for field, value in update_data.items():
@@ -168,6 +208,7 @@ async def update_organization(
@router.delete("/{org_id}", summary="删除组织") @router.delete("/{org_id}", summary="删除组织")
async def delete_organization( async def delete_organization(
org_id: str, org_id: str,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
"""删除组织(会级联删除所有成员关系)""" """删除组织(会级联删除所有成员关系)"""
@@ -179,6 +220,10 @@ async def delete_organization(
if not db_org: if not db_org:
raise HTTPException(status_code=404, detail="组织不存在") raise HTTPException(status_code=404, detail="组织不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(db_org.project_id, user_id, db)
await db.delete(db_org) await db.delete(db_org)
await db.commit() await db.commit()
@@ -191,6 +236,7 @@ async def delete_organization(
@router.get("/{org_id}/members", response_model=List[OrganizationMemberDetailResponse], summary="获取组织成员") @router.get("/{org_id}/members", response_model=List[OrganizationMemberDetailResponse], summary="获取组织成员")
async def get_organization_members( async def get_organization_members(
org_id: str, org_id: str,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
""" """
@@ -202,9 +248,14 @@ async def get_organization_members(
org_result = await db.execute( org_result = await db.execute(
select(Organization).where(Organization.id == org_id) select(Organization).where(Organization.id == org_id)
) )
if not org_result.scalar_one_or_none(): org = org_result.scalar_one_or_none()
if not org:
raise HTTPException(status_code=404, detail="组织不存在") raise HTTPException(status_code=404, detail="组织不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(org.project_id, user_id, db)
# 获取成员列表 # 获取成员列表
result = await db.execute( result = await db.execute(
select(OrganizationMember) select(OrganizationMember)
@@ -244,6 +295,7 @@ async def get_organization_members(
async def add_organization_member( async def add_organization_member(
org_id: str, org_id: str,
member: OrganizationMemberCreate, member: OrganizationMemberCreate,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
""" """
@@ -260,6 +312,10 @@ async def add_organization_member(
if not org: if not org:
raise HTTPException(status_code=404, detail="组织不存在") raise HTTPException(status_code=404, detail="组织不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(org.project_id, user_id, db)
# 验证角色存在 # 验证角色存在
char_result = await db.execute( char_result = await db.execute(
select(Character).where(Character.id == member.character_id) select(Character).where(Character.id == member.character_id)
@@ -304,6 +360,7 @@ async def add_organization_member(
async def update_organization_member( async def update_organization_member(
member_id: str, member_id: str,
member: OrganizationMemberUpdate, member: OrganizationMemberUpdate,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
"""更新组织成员的职位、忠诚度等信息""" """更新组织成员的职位、忠诚度等信息"""
@@ -315,6 +372,14 @@ async def update_organization_member(
if not db_member: if not db_member:
raise HTTPException(status_code=404, detail="成员记录不存在") raise HTTPException(status_code=404, detail="成员记录不存在")
# 通过成员所属的组织验证用户权限
org_result = await db.execute(
select(Organization).where(Organization.id == db_member.organization_id)
)
org = org_result.scalar_one()
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(org.project_id, user_id, db)
# 更新字段 # 更新字段
update_data = member.model_dump(exclude_unset=True) update_data = member.model_dump(exclude_unset=True)
for field, value in update_data.items(): for field, value in update_data.items():
@@ -330,6 +395,7 @@ async def update_organization_member(
@router.delete("/members/{member_id}", summary="移除组织成员") @router.delete("/members/{member_id}", summary="移除组织成员")
async def remove_organization_member( async def remove_organization_member(
member_id: str, member_id: str,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
""" """
@@ -350,6 +416,10 @@ async def remove_organization_member(
select(Organization).where(Organization.id == db_member.organization_id) select(Organization).where(Organization.id == db_member.organization_id)
) )
org = org_result.scalar_one() org = org_result.scalar_one()
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(org.project_id, user_id, db)
org.member_count = max(0, org.member_count - 1) org.member_count = max(0, org.member_count - 1)
await db.delete(db_member) await db.delete(db_member)
@@ -360,7 +430,8 @@ async def remove_organization_member(
@router.post("/generate", response_model=CharacterResponse, summary="AI生成组织") @router.post("/generate", response_model=CharacterResponse, summary="AI生成组织")
async def generate_organization( async def generate_organization(
request: OrganizationGenerateRequest, gen_request: OrganizationGenerateRequest,
http_request: Request,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service) user_ai_service: AIService = Depends(get_user_ai_service)
): ):
@@ -372,19 +443,15 @@ async def generate_organization(
生成内容包括:组织名称、类型、特性、背景、目的、势力等级等 生成内容包括:组织名称、类型、特性、背景、目的、势力等级等
""" """
# 验证项目是否存在并获取项目信息 # 验证用户权限
result = await db.execute( user_id = getattr(http_request.state, 'user_id', None)
select(Project).where(Project.id == request.project_id) project = await verify_project_access(gen_request.project_id, user_id, db)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
try: try:
# 获取已存在的角色和组织列表 # 获取已存在的角色和组织列表
existing_chars_result = await db.execute( existing_chars_result = await db.execute(
select(Character) select(Character)
.where(Character.project_id == request.project_id) .where(Character.project_id == gen_request.project_id)
.order_by(Character.created_at.desc()) .order_by(Character.created_at.desc())
) )
existing_characters = existing_chars_result.scalars().all() existing_characters = existing_chars_result.scalars().all()
@@ -422,10 +489,10 @@ async def generate_organization(
# 构建用户输入信息 # 构建用户输入信息
user_input = f""" user_input = f"""
用户要求: 用户要求:
- 组织名称:{request.name or '请AI生成'} - 组织名称:{gen_request.name or '请AI生成'}
- 组织类型:{request.organization_type or '请AI根据世界观决定'} - 组织类型:{gen_request.organization_type or '请AI根据世界观决定'}
- 背景设定:{request.background or '无特殊要求'} - 背景设定:{gen_request.background or '无特殊要求'}
- 其他要求:{request.requirements or ''} - 其他要求:{gen_request.requirements or ''}
""" """
# 使用统一的提示词服务 # 使用统一的提示词服务
@@ -435,10 +502,10 @@ async def generate_organization(
) )
# 调用AI生成组织 # 调用AI生成组织
logger.info(f"🎯 开始为项目 {request.project_id} 生成组织") logger.info(f"🎯 开始为项目 {gen_request.project_id} 生成组织")
logger.info(f" - 组织名:{request.name or 'AI生成'}") logger.info(f" - 组织名:{gen_request.name or 'AI生成'}")
logger.info(f" - 组织类型:{request.organization_type or 'AI决定'}") logger.info(f" - 组织类型:{gen_request.organization_type or 'AI决定'}")
logger.info(f" - 背景设定:{request.background or ''}") logger.info(f" - 背景设定:{gen_request.background or ''}")
logger.info(f" - AI提供商:{user_ai_service.api_provider}") logger.info(f" - AI提供商:{user_ai_service.api_provider}")
logger.info(f" - AI模型:{user_ai_service.default_model}") logger.info(f" - AI模型:{user_ai_service.default_model}")
logger.info(f" - Prompt长度:{len(prompt)} 字符") logger.info(f" - Prompt长度:{len(prompt)} 字符")
@@ -492,8 +559,8 @@ async def generate_organization(
# 创建角色记录(组织也是角色的一种) # 创建角色记录(组织也是角色的一种)
character = Character( character = Character(
project_id=request.project_id, project_id=gen_request.project_id,
name=organization_data.get("name", request.name or "未命名组织"), name=organization_data.get("name", gen_request.name or "未命名组织"),
is_organization=True, is_organization=True,
role_type="supporting", # 组织通常作为配角 role_type="supporting", # 组织通常作为配角
personality=organization_data.get("personality", ""), personality=organization_data.get("personality", ""),
@@ -518,7 +585,7 @@ async def generate_organization(
# 自动创建Organization详情记录 # 自动创建Organization详情记录
organization = Organization( organization = Organization(
character_id=character.id, character_id=character.id,
project_id=request.project_id, project_id=gen_request.project_id,
member_count=0, member_count=0,
power_level=organization_data.get("power_level", 50), power_level=organization_data.get("power_level", 50),
location=organization_data.get("location"), location=organization_data.get("location"),
@@ -532,7 +599,7 @@ async def generate_organization(
# 记录生成历史 # 记录生成历史
history = GenerationHistory( history = GenerationHistory(
project_id=request.project_id, project_id=gen_request.project_id,
prompt=prompt, prompt=prompt,
generated_content=ai_content, generated_content=ai_content,
model=user_ai_service.default_model model=user_ai_service.default_model
@@ -542,7 +609,7 @@ async def generate_organization(
await db.commit() await db.commit()
await db.refresh(character) await db.refresh(character)
logger.info(f"🎉 成功为项目 {request.project_id} 生成组织: {character.name}") logger.info(f"🎉 成功为项目 {gen_request.project_id} 生成组织: {character.name}")
return character return character
+81 -23
View File
@@ -30,19 +30,49 @@ router = APIRouter(prefix="/outlines", tags=["大纲管理"])
logger = get_logger(__name__) logger = get_logger(__name__)
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
"""
验证用户是否有权访问指定项目
Args:
project_id: 项目ID
user_id: 用户ID
db: 数据库会话
Returns:
Project: 项目对象
Raises:
HTTPException: 401 未登录,404 项目不存在或无权访问
"""
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
result = await db.execute(
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
return project
@router.post("", response_model=OutlineResponse, summary="创建大纲") @router.post("", response_model=OutlineResponse, summary="创建大纲")
async def create_outline( async def create_outline(
outline: OutlineCreate, outline: OutlineCreate,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
"""创建新的章节大纲,同时创建对应的章节记录""" """创建新的章节大纲,同时创建对应的章节记录"""
# 验证项目是否存在 # 验证用户权限
result = await db.execute( user_id = getattr(request.state, 'user_id', None)
select(Project).where(Project.id == outline.project_id) await verify_project_access(outline.project_id, user_id, db)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 创建大纲 # 创建大纲
db_outline = Outline(**outline.model_dump()) db_outline = Outline(**outline.model_dump())
@@ -66,9 +96,14 @@ async def create_outline(
@router.get("", response_model=OutlineListResponse, summary="获取大纲列表") @router.get("", response_model=OutlineListResponse, summary="获取大纲列表")
async def get_outlines( async def get_outlines(
project_id: str, project_id: str,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
"""获取指定项目的所有大纲""" """获取指定项目的所有大纲"""
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(project_id, user_id, db)
# 获取总数 # 获取总数
count_result = await db.execute( count_result = await db.execute(
select(func.count(Outline.id)).where(Outline.project_id == project_id) select(func.count(Outline.id)).where(Outline.project_id == project_id)
@@ -89,9 +124,14 @@ async def get_outlines(
@router.get("/project/{project_id}", response_model=OutlineListResponse, summary="获取项目的所有大纲") @router.get("/project/{project_id}", response_model=OutlineListResponse, summary="获取项目的所有大纲")
async def get_project_outlines( async def get_project_outlines(
project_id: str, project_id: str,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
"""获取指定项目的所有大纲(路径参数版本)""" """获取指定项目的所有大纲(路径参数版本)"""
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(project_id, user_id, db)
# 获取总数 # 获取总数
count_result = await db.execute( count_result = await db.execute(
select(func.count(Outline.id)).where(Outline.project_id == project_id) select(func.count(Outline.id)).where(Outline.project_id == project_id)
@@ -112,6 +152,7 @@ async def get_project_outlines(
@router.get("/{outline_id}", response_model=OutlineResponse, summary="获取大纲详情") @router.get("/{outline_id}", response_model=OutlineResponse, summary="获取大纲详情")
async def get_outline( async def get_outline(
outline_id: str, outline_id: str,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
"""根据ID获取大纲详情""" """根据ID获取大纲详情"""
@@ -123,6 +164,10 @@ async def get_outline(
if not outline: if not outline:
raise HTTPException(status_code=404, detail="大纲不存在") raise HTTPException(status_code=404, detail="大纲不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(outline.project_id, user_id, db)
return outline return outline
@@ -130,6 +175,7 @@ async def get_outline(
async def update_outline( async def update_outline(
outline_id: str, outline_id: str,
outline_update: OutlineUpdate, outline_update: OutlineUpdate,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
"""更新大纲信息,同步更新对应章节和structure字段""" """更新大纲信息,同步更新对应章节和structure字段"""
@@ -141,6 +187,10 @@ async def update_outline(
if not outline: if not outline:
raise HTTPException(status_code=404, detail="大纲不存在") raise HTTPException(status_code=404, detail="大纲不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(outline.project_id, user_id, db)
# 更新字段 # 更新字段
update_data = outline_update.model_dump(exclude_unset=True) update_data = outline_update.model_dump(exclude_unset=True)
for field, value in update_data.items(): for field, value in update_data.items():
@@ -196,6 +246,7 @@ async def update_outline(
@router.delete("/{outline_id}", summary="删除大纲") @router.delete("/{outline_id}", summary="删除大纲")
async def delete_outline( async def delete_outline(
outline_id: str, outline_id: str,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
"""删除大纲,同步删除章节,并重新排序后续项""" """删除大纲,同步删除章节,并重新排序后续项"""
@@ -207,6 +258,10 @@ async def delete_outline(
if not outline: if not outline:
raise HTTPException(status_code=404, detail="大纲不存在") raise HTTPException(status_code=404, detail="大纲不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(outline.project_id, user_id, db)
project_id = outline.project_id project_id = outline.project_id
deleted_order = outline.order_index deleted_order = outline.order_index
@@ -252,7 +307,8 @@ async def delete_outline(
@router.post("/reorder", summary="批量重排序大纲") @router.post("/reorder", summary="批量重排序大纲")
async def reorder_outlines( async def reorder_outlines(
request: OutlineReorderRequest, reorder_request: OutlineReorderRequest,
http_request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
""" """
@@ -261,10 +317,20 @@ async def reorder_outlines(
策略:先收集所有变更,最后一次性提交,避免临时冲突 策略:先收集所有变更,最后一次性提交,避免临时冲突
""" """
try: try:
# 验证用户权限(通过第一个大纲的project_id)
user_id = getattr(http_request.state, 'user_id', None)
if reorder_request.orders and len(reorder_request.orders) > 0:
first_outline_result = await db.execute(
select(Outline).where(Outline.id == reorder_request.orders[0].id)
)
first_outline = first_outline_result.scalar_one_or_none()
if first_outline:
await verify_project_access(first_outline.project_id, user_id, db)
# 第一步:收集所有大纲和对应的章节 # 第一步:收集所有大纲和对应的章节
outline_chapter_map = {} # {outline_id: (outline, chapter, old_order, new_order)} outline_chapter_map = {} # {outline_id: (outline, chapter, old_order, new_order)}
for item in request.orders: for item in reorder_request.orders:
outline_id = item.id outline_id = item.id
new_order = item.order_index new_order = item.order_index
@@ -341,13 +407,9 @@ async def generate_outline(
- new: 强制全新生成 - new: 强制全新生成
- continue: 强制续写模式 - continue: 强制续写模式
""" """
# 验证项目是否存在 # 验证用户权限
result = await db.execute( user_id = getattr(http_request.state, 'user_id', None)
select(Project).where(Project.id == request.project_id) project = await verify_project_access(request.project_id, user_id, db)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
try: try:
# 获取现有大纲(强制从数据库获取最新数据,包括用户手动修改的内容) # 获取现有大纲(强制从数据库获取最新数据,包括用户手动修改的内容)
@@ -1472,13 +1534,9 @@ async def generate_outline_stream(
"model": "gpt-4" // 可选 "model": "gpt-4" // 可选
} }
""" """
# 验证项目是否存在 # 验证用户权限
result = await db.execute( user_id = getattr(request.state, 'user_id', None)
select(Project).where(Project.id == data.get("project_id")) project = await verify_project_access(data.get("project_id"), user_id, db)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 判断模式 # 判断模式
mode = data.get("mode", "auto") mode = data.get("mode", "auto")
+155 -41
View File
@@ -41,17 +41,31 @@ router = APIRouter(prefix="/projects", tags=["项目管理"])
@router.post("", response_model=ProjectResponse, summary="创建项目") @router.post("", response_model=ProjectResponse, summary="创建项目")
async def create_project( async def create_project(
project: ProjectCreate, project: ProjectCreate,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
request: Request = None
): ):
try: try:
logger.info(f"创建新项目: {project.title}") # 从认证中间件获取用户ID
db_project = Project(**project.model_dump()) user_id = getattr(request.state, 'user_id', None)
if not user_id:
logger.warning("未登录用户尝试创建项目")
raise HTTPException(status_code=401, detail="未登录")
logger.info(f"创建新项目: {project.title}, user_id={user_id}")
# 创建项目时自动设置user_id
project_data = project.model_dump()
project_data['user_id'] = user_id
db_project = Project(**project_data)
db.add(db_project) db.add(db_project)
await db.commit() await db.commit()
await db.refresh(db_project) await db.refresh(db_project)
logger.info(f"项目创建成功: {db_project.id}") logger.info(f"项目创建成功: project_id={db_project.id}, user_id={user_id}")
return db_project return db_project
except HTTPException:
raise
except Exception as e: except Exception as e:
logger.error(f"创建项目失败: {str(e)}", exc_info=True) logger.error(f"创建项目失败: {str(e)}", exc_info=True)
raise raise
@@ -61,24 +75,38 @@ async def create_project(
async def get_projects( async def get_projects(
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
request: Request = None
): ):
"""获取所有项目列表""" """获取当前用户的项目列表"""
try: try:
logger.debug(f"获取项目列表: skip={skip}, limit={limit}") # 从认证中间件获取用户ID
count_result = await db.execute(select(func.count(Project.id))) user_id = getattr(request.state, 'user_id', None)
if not user_id:
logger.warning("未登录用户尝试获取项目列表")
raise HTTPException(status_code=401, detail="未登录")
logger.debug(f"获取项目列表: user_id={user_id}, skip={skip}, limit={limit}")
# 只查询当前用户的项目
count_result = await db.execute(
select(func.count(Project.id)).where(Project.user_id == user_id)
)
total = count_result.scalar_one() total = count_result.scalar_one()
result = await db.execute( result = await db.execute(
select(Project) select(Project)
.where(Project.user_id == user_id)
.order_by(Project.updated_at.desc()) .order_by(Project.updated_at.desc())
.offset(skip) .offset(skip)
.limit(limit) .limit(limit)
) )
projects = result.scalars().all() projects = result.scalars().all()
logger.info(f"获取项目列表成功: 共{total}个项目") logger.info(f"获取项目列表成功: user_id={user_id},{total}个项目")
return ProjectListResponse(total=total, items=projects) return ProjectListResponse(total=total, items=projects)
except HTTPException:
raise
except Exception as e: except Exception as e:
logger.error(f"获取项目列表失败: {str(e)}", exc_info=True) logger.error(f"获取项目列表失败: {str(e)}", exc_info=True)
raise raise
@@ -87,17 +115,29 @@ async def get_projects(
@router.get("/{project_id}", response_model=ProjectResponse, summary="获取项目详情") @router.get("/{project_id}", response_model=ProjectResponse, summary="获取项目详情")
async def get_project( async def get_project(
project_id: str, project_id: str,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
request: Request = None
): ):
try: try:
logger.debug(f"获取项目详情: {project_id}") # 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
logger.warning("未登录用户尝试获取项目详情")
raise HTTPException(status_code=401, detail="未登录")
logger.debug(f"获取项目详情: project_id={project_id}, user_id={user_id}")
# 只查询当前用户的项目
result = await db.execute( result = await db.execute(
select(Project).where(Project.id == project_id) select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
) )
project = result.scalar_one_or_none() project = result.scalar_one_or_none()
if not project: if not project:
logger.warning(f"项目不存在: {project_id}") logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在") raise HTTPException(status_code=404, detail="项目不存在")
logger.info(f"获取项目详情成功: {project.title}") logger.info(f"获取项目详情成功: {project.title}")
@@ -113,17 +153,29 @@ async def get_project(
async def update_project( async def update_project(
project_id: str, project_id: str,
project_update: ProjectUpdate, project_update: ProjectUpdate,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
request: Request = None
): ):
try: try:
logger.info(f"更新项目: {project_id}") # 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
logger.warning("未登录用户尝试更新项目")
raise HTTPException(status_code=401, detail="未登录")
logger.info(f"更新项目: project_id={project_id}, user_id={user_id}")
# 只查询当前用户的项目
result = await db.execute( result = await db.execute(
select(Project).where(Project.id == project_id) select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
) )
project = result.scalar_one_or_none() project = result.scalar_one_or_none()
if not project: if not project:
logger.warning(f"项目不存在: {project_id}") logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在") raise HTTPException(status_code=404, detail="项目不存在")
update_data = project_update.model_dump(exclude_unset=True) update_data = project_update.model_dump(exclude_unset=True)
@@ -149,22 +201,30 @@ async def delete_project(
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
try: try:
logger.info(f"删除项目: {project_id}") # 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
logger.warning("未登录用户尝试删除项目")
raise HTTPException(status_code=401, detail="未登录")
logger.info(f"删除项目: project_id={project_id}, user_id={user_id}")
# 只查询当前用户的项目
result = await db.execute( result = await db.execute(
select(Project).where(Project.id == project_id) select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
) )
project = result.scalar_one_or_none() project = result.scalar_one_or_none()
if not project: if not project:
logger.warning(f"项目不存在: {project_id}") logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在") raise HTTPException(status_code=404, detail="项目不存在")
project_title = project.title project_title = project.title
# 从认证中间件获取用户ID # 删除向量数据库中的记忆(user_id已在上面获取)
user_id = getattr(request.state, 'user_id', None)
# 删除向量数据库中的记忆
if user_id: if user_id:
try: try:
await memory_service.delete_project_memories(user_id, project_id) await memory_service.delete_project_memories(user_id, project_id)
@@ -234,22 +294,33 @@ async def delete_project(
@router.get("/{project_id}/export", summary="导出项目章节为TXT") @router.get("/{project_id}/export", summary="导出项目章节为TXT")
async def export_project_chapters( async def export_project_chapters(
project_id: str, project_id: str,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
request: Request = None
): ):
""" """
导出项目的所有章节内容为TXT文本文件 导出项目的所有章节内容为TXT文本文件
按章节顺序组织,包含项目基本信息 按章节顺序组织,包含项目基本信息
""" """
try: try:
logger.info(f"开始导出项目: {project_id}") # 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
logger.warning("未登录用户尝试导出项目")
raise HTTPException(status_code=401, detail="未登录")
logger.info(f"开始导出项目: project_id={project_id}, user_id={user_id}")
# 只查询当前用户的项目
result = await db.execute( result = await db.execute(
select(Project).where(Project.id == project_id) select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
) )
project = result.scalar_one_or_none() project = result.scalar_one_or_none()
if not project: if not project:
logger.warning(f"项目不存在: {project_id}") logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在") raise HTTPException(status_code=404, detail="项目不存在")
chapters_result = await db.execute( chapters_result = await db.execute(
@@ -326,6 +397,7 @@ async def export_project_chapters(
@router.post("/{project_id}/check-consistency", summary="检查数据一致性") @router.post("/{project_id}/check-consistency", summary="检查数据一致性")
async def check_project_consistency( async def check_project_consistency(
project_id: str, project_id: str,
request: Request,
auto_fix: bool = True, auto_fix: bool = True,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
@@ -343,15 +415,25 @@ async def check_project_consistency(
- organization_members: 验证组织成员数据完整性 - organization_members: 验证组织成员数据完整性
""" """
try: try:
logger.info(f"开始数据一致性检查: {project_id}, auto_fix={auto_fix}") # 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
logger.warning("未登录用户尝试检查数据一致性")
raise HTTPException(status_code=401, detail="未登录")
logger.info(f"开始数据一致性检查: project_id={project_id}, user_id={user_id}, auto_fix={auto_fix}")
# 只查询当前用户的项目
result = await db.execute( result = await db.execute(
select(Project).where(Project.id == project_id) select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
) )
project = result.scalar_one_or_none() project = result.scalar_one_or_none()
if not project: if not project:
logger.warning(f"项目不存在: {project_id}") logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在") raise HTTPException(status_code=404, detail="项目不存在")
report = await run_full_data_consistency_check(project_id, db, auto_fix) report = await run_full_data_consistency_check(project_id, db, auto_fix)
@@ -369,6 +451,7 @@ async def check_project_consistency(
@router.post("/{project_id}/fix-organizations", summary="修复组织记录") @router.post("/{project_id}/fix-organizations", summary="修复组织记录")
async def fix_project_organizations( async def fix_project_organizations(
project_id: str, project_id: str,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
""" """
@@ -377,15 +460,25 @@ async def fix_project_organizations(
为所有is_organization=True但没有Organization记录的Character创建记录 为所有is_organization=True但没有Organization记录的Character创建记录
""" """
try: try:
logger.info(f"开始修复组织记录: {project_id}") # 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
logger.warning("未登录用户尝试修复组织记录")
raise HTTPException(status_code=401, detail="未登录")
logger.info(f"开始修复组织记录: project_id={project_id}, user_id={user_id}")
# 只查询当前用户的项目
result = await db.execute( result = await db.execute(
select(Project).where(Project.id == project_id) select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
) )
project = result.scalar_one_or_none() project = result.scalar_one_or_none()
if not project: if not project:
logger.warning(f"项目不存在: {project_id}") logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在") raise HTTPException(status_code=404, detail="项目不存在")
fixed_count, total_count = await fix_missing_organization_records(project_id, db) fixed_count, total_count = await fix_missing_organization_records(project_id, db)
@@ -407,6 +500,7 @@ async def fix_project_organizations(
@router.post("/{project_id}/fix-member-counts", summary="修复成员计数") @router.post("/{project_id}/fix-member-counts", summary="修复成员计数")
async def fix_project_member_counts( async def fix_project_member_counts(
project_id: str, project_id: str,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
""" """
@@ -415,15 +509,25 @@ async def fix_project_member_counts(
从实际成员记录重新计算每个组织的member_count 从实际成员记录重新计算每个组织的member_count
""" """
try: try:
logger.info(f"开始修复成员计数: {project_id}") # 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
logger.warning("未登录用户尝试修复成员计数")
raise HTTPException(status_code=401, detail="未登录")
logger.info(f"开始修复成员计数: project_id={project_id}, user_id={user_id}")
# 只查询当前用户的项目
result = await db.execute( result = await db.execute(
select(Project).where(Project.id == project_id) select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
) )
project = result.scalar_one_or_none() project = result.scalar_one_or_none()
if not project: if not project:
logger.warning(f"项目不存在: {project_id}") logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在") raise HTTPException(status_code=404, detail="项目不存在")
fixed_count, total_count = await fix_organization_member_counts(project_id, db) fixed_count, total_count = await fix_organization_member_counts(project_id, db)
@@ -445,6 +549,7 @@ async def fix_project_member_counts(
@router.post("/{project_id}/export-data", summary="导出项目数据为JSON") @router.post("/{project_id}/export-data", summary="导出项目数据为JSON")
async def export_project_data( async def export_project_data(
project_id: str, project_id: str,
request: Request,
options: ExportOptions, options: ExportOptions,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
@@ -459,16 +564,25 @@ async def export_project_data(
JSON文件下载 JSON文件下载
""" """
try: try:
logger.info(f"开始导出项目数据: {project_id}") # 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
logger.warning("未登录用户尝试导出项目数据")
raise HTTPException(status_code=401, detail="未登录")
# 检查项目是否存在 logger.info(f"开始导出项目数据: project_id={project_id}, user_id={user_id}")
# 只查询当前用户的项目
result = await db.execute( result = await db.execute(
select(Project).where(Project.id == project_id) select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
) )
project = result.scalar_one_or_none() project = result.scalar_one_or_none()
if not project: if not project:
logger.warning(f"项目不存在: {project_id}") logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在") raise HTTPException(status_code=404, detail="项目不存在")
# 导出数据 # 导出数据
+47 -1
View File
@@ -1,5 +1,5 @@
"""关系管理API""" """关系管理API"""
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query, Request
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, or_, and_ from sqlalchemy import select, or_, and_
from typing import List, Optional from typing import List, Optional
@@ -12,6 +12,7 @@ from app.models.relationship import (
OrganizationMember OrganizationMember
) )
from app.models.character import Character from app.models.character import Character
from app.models.project import Project
from app.schemas.relationship import ( from app.schemas.relationship import (
RelationshipTypeResponse, RelationshipTypeResponse,
CharacterRelationshipCreate, CharacterRelationshipCreate,
@@ -27,6 +28,26 @@ router = APIRouter(prefix="/relationships", tags=["关系管理"])
logger = get_logger(__name__) logger = get_logger(__name__)
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
"""验证用户是否有权访问指定项目"""
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
result = await db.execute(
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
return project
@router.get("/types", response_model=List[RelationshipTypeResponse], summary="获取关系类型列表") @router.get("/types", response_model=List[RelationshipTypeResponse], summary="获取关系类型列表")
async def get_relationship_types(db: AsyncSession = Depends(get_db)): async def get_relationship_types(db: AsyncSession = Depends(get_db)):
"""获取所有预定义的关系类型""" """获取所有预定义的关系类型"""
@@ -38,9 +59,14 @@ async def get_relationship_types(db: AsyncSession = Depends(get_db)):
@router.get("/project/{project_id}", response_model=List[CharacterRelationshipResponse], summary="获取项目的所有关系") @router.get("/project/{project_id}", response_model=List[CharacterRelationshipResponse], summary="获取项目的所有关系")
async def get_project_relationships( async def get_project_relationships(
project_id: str, project_id: str,
request: Request,
character_id: Optional[str] = Query(None, description="筛选特定角色的关系"), character_id: Optional[str] = Query(None, description="筛选特定角色的关系"),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(project_id, user_id, db)
""" """
获取项目中的所有角色关系 获取项目中的所有角色关系
@@ -70,8 +96,13 @@ async def get_project_relationships(
@router.get("/graph/{project_id}", response_model=RelationshipGraphData, summary="获取关系图谱数据") @router.get("/graph/{project_id}", response_model=RelationshipGraphData, summary="获取关系图谱数据")
async def get_relationship_graph( async def get_relationship_graph(
project_id: str, project_id: str,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(project_id, user_id, db)
""" """
获取用于可视化的关系图谱数据 获取用于可视化的关系图谱数据
@@ -122,6 +153,7 @@ async def get_relationship_graph(
@router.post("/", response_model=CharacterRelationshipResponse, summary="创建角色关系") @router.post("/", response_model=CharacterRelationshipResponse, summary="创建角色关系")
async def create_relationship( async def create_relationship(
relationship: CharacterRelationshipCreate, relationship: CharacterRelationshipCreate,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
""" """
@@ -131,6 +163,10 @@ async def create_relationship(
- 可以指定预定义的关系类型或自定义关系名称 - 可以指定预定义的关系类型或自定义关系名称
- 可以设置亲密度、状态等属性 - 可以设置亲密度、状态等属性
""" """
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(relationship.project_id, user_id, db)
# 验证角色是否存在 # 验证角色是否存在
char_from = await db.execute( char_from = await db.execute(
select(Character).where(Character.id == relationship.character_from_id) select(Character).where(Character.id == relationship.character_from_id)
@@ -161,6 +197,7 @@ async def create_relationship(
async def update_relationship( async def update_relationship(
relationship_id: str, relationship_id: str,
relationship: CharacterRelationshipUpdate, relationship: CharacterRelationshipUpdate,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
"""更新角色关系的属性(亲密度、状态等)""" """更新角色关系的属性(亲密度、状态等)"""
@@ -174,6 +211,10 @@ async def update_relationship(
if not db_rel: if not db_rel:
raise HTTPException(status_code=404, detail="关系不存在") raise HTTPException(status_code=404, detail="关系不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(db_rel.project_id, user_id, db)
# 更新字段 # 更新字段
update_data = relationship.model_dump(exclude_unset=True) update_data = relationship.model_dump(exclude_unset=True)
for field, value in update_data.items(): for field, value in update_data.items():
@@ -189,6 +230,7 @@ async def update_relationship(
@router.delete("/{relationship_id}", summary="删除关系") @router.delete("/{relationship_id}", summary="删除关系")
async def delete_relationship( async def delete_relationship(
relationship_id: str, relationship_id: str,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
"""删除角色关系""" """删除角色关系"""
@@ -202,6 +244,10 @@ async def delete_relationship(
if not db_rel: if not db_rel:
raise HTTPException(status_code=404, detail="关系不存在") raise HTTPException(status_code=404, detail="关系不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(db_rel.project_id, user_id, db)
await db.delete(db_rel) await db.delete(db_rel)
await db.commit() await db.commit()
+6
View File
@@ -183,7 +183,13 @@ async def world_building_generator(
# 保存到数据库 # 保存到数据库
yield await SSEResponse.send_progress("保存到数据库...", 90) yield await SSEResponse.send_progress("保存到数据库...", 90)
# 确保user_id存在
if not user_id:
yield await SSEResponse.send_error("用户ID缺失,无法创建项目", 401)
return
project = Project( project = Project(
user_id=user_id, # 添加user_id字段
title=title, title=title,
description=description, description=description,
theme=theme, theme=theme,
+50 -30
View File
@@ -1,5 +1,5 @@
"""写作风格管理 API""" """写作风格管理 API"""
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, delete from sqlalchemy import select, func, delete
from typing import List from typing import List
@@ -16,8 +16,30 @@ from ..schemas.writing_style import (
SetDefaultStyleRequest SetDefaultStyleRequest
) )
from ..services.prompt_service import WritingStyleManager from ..services.prompt_service import WritingStyleManager
from ..logger import get_logger
router = APIRouter(prefix="/writing-styles", tags=["writing-styles"]) router = APIRouter(prefix="/writing-styles", tags=["writing-styles"])
logger = get_logger(__name__)
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
"""验证用户是否有权访问指定项目"""
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
result = await db.execute(
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
return project
@router.get("/presets/list", response_model=List[dict]) @router.get("/presets/list", response_model=List[dict])
@@ -42,6 +64,7 @@ async def get_preset_styles():
@router.post("", response_model=WritingStyleResponse, status_code=201) @router.post("", response_model=WritingStyleResponse, status_code=201)
async def create_writing_style( async def create_writing_style(
style_data: WritingStyleCreate, style_data: WritingStyleCreate,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
""" """
@@ -50,13 +73,9 @@ async def create_writing_style(
- **基于预设创建**:提供 preset_id,系统会自动填充预设内容 - **基于预设创建**:提供 preset_id,系统会自动填充预设内容
- **完全自定义**:不提供 preset_id,需要手动填写所有字段 - **完全自定义**:不提供 preset_id,需要手动填写所有字段
""" """
# 验证项目是否存在 # 验证用户权限
result = await db.execute( user_id = getattr(request.state, 'user_id', None)
select(Project).where(Project.id == style_data.project_id) await verify_project_access(style_data.project_id, user_id, db)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 如果基于预设创建,获取预设内容 # 如果基于预设创建,获取预设内容
if style_data.preset_id: if style_data.preset_id:
@@ -120,6 +139,7 @@ async def create_writing_style(
@router.get("/project/{project_id}", response_model=WritingStyleListResponse) @router.get("/project/{project_id}", response_model=WritingStyleListResponse)
async def get_project_styles( async def get_project_styles(
project_id: str, project_id: str,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
""" """
@@ -128,13 +148,9 @@ async def get_project_styles(
返回:全局预设风格 + 该项目的自定义风格 返回:全局预设风格 + 该项目的自定义风格
按 order_index 排序,并标记哪个是当前项目的默认风格 按 order_index 排序,并标记哪个是当前项目的默认风格
""" """
# 验证项目是否存在 # 验证用户权限
result = await db.execute( user_id = getattr(request.state, 'user_id', None)
select(Project).where(Project.id == project_id) await verify_project_access(project_id, user_id, db)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 获取该项目的默认风格ID # 获取该项目的默认风格ID
result = await db.execute( result = await db.execute(
@@ -222,6 +238,7 @@ async def get_writing_style(
async def update_writing_style( async def update_writing_style(
style_id: int, style_id: int,
style_data: WritingStyleUpdate, style_data: WritingStyleUpdate,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
""" """
@@ -241,6 +258,10 @@ async def update_writing_style(
if style.project_id is None: if style.project_id is None:
raise HTTPException(status_code=403, detail="不能修改全局预设风格,只能修改自定义风格") raise HTTPException(status_code=403, detail="不能修改全局预设风格,只能修改自定义风格")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(style.project_id, user_id, db)
# 更新字段 # 更新字段
update_data = style_data.model_dump(exclude_unset=True) update_data = style_data.model_dump(exclude_unset=True)
@@ -279,6 +300,7 @@ async def update_writing_style(
@router.delete("/{style_id}", status_code=204) @router.delete("/{style_id}", status_code=204)
async def delete_writing_style( async def delete_writing_style(
style_id: int, style_id: int,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
""" """
@@ -300,6 +322,10 @@ async def delete_writing_style(
if style.project_id is None: if style.project_id is None:
raise HTTPException(status_code=403, detail="不能删除全局预设风格,只能删除自定义风格") raise HTTPException(status_code=403, detail="不能删除全局预设风格,只能删除自定义风格")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(style.project_id, user_id, db)
# 检查是否有项目将其设置为默认风格 # 检查是否有项目将其设置为默认风格
result = await db.execute( result = await db.execute(
select(ProjectDefaultStyle).where(ProjectDefaultStyle.style_id == style_id) select(ProjectDefaultStyle).where(ProjectDefaultStyle.style_id == style_id)
@@ -321,6 +347,7 @@ async def delete_writing_style(
async def set_default_style( async def set_default_style(
style_id: int, style_id: int,
request_data: SetDefaultStyleRequest, request_data: SetDefaultStyleRequest,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
""" """
@@ -335,13 +362,9 @@ async def set_default_style(
""" """
project_id = request_data.project_id project_id = request_data.project_id
# 验证项目是否存在 # 验证用户权限
result = await db.execute( user_id = getattr(request.state, 'user_id', None)
select(Project).where(Project.id == project_id) await verify_project_access(project_id, user_id, db)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 验证风格是否存在 # 验证风格是否存在
result = await db.execute( result = await db.execute(
@@ -379,6 +402,7 @@ async def set_default_style(
@router.post("/project/{project_id}/init-defaults", response_model=WritingStyleListResponse) @router.post("/project/{project_id}/init-defaults", response_model=WritingStyleListResponse)
async def initialize_default_styles( async def initialize_default_styles(
project_id: str, project_id: str,
request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
""" """
@@ -387,13 +411,9 @@ async def initialize_default_styles(
新架构下,预设风格是全局的,不需要为每个项目单独初始化 新架构下,预设风格是全局的,不需要为每个项目单独初始化
该接口保留用于兼容性,直接返回项目可用的所有风格 该接口保留用于兼容性,直接返回项目可用的所有风格
""" """
# 验证项目是否存在 # 验证用户权限
result = await db.execute( user_id = getattr(request.state, 'user_id', None)
select(Project).where(Project.id == project_id) await verify_project_access(project_id, user_id, db)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 直接返回项目可用的所有风格(全局预设 + 项目自定义) # 直接返回项目可用的所有风格(全局预设 + 项目自定义)
return await get_project_styles(project_id, db) return await get_project_styles(project_id, request, db)
+31 -6
View File
@@ -3,6 +3,7 @@ from pydantic_settings import BaseSettings
from typing import Optional from typing import Optional
from pathlib import Path from pathlib import Path
import logging import logging
import os
# 获取项目根目录(从backend/app/config.py向上两级) # 获取项目根目录(从backend/app/config.py向上两级)
PROJECT_ROOT = Path(__file__).parent.parent PROJECT_ROOT = Path(__file__).parent.parent
@@ -12,13 +13,15 @@ DATA_DIR.mkdir(exist_ok=True)
# 配置模块使用标准logging(在logger.py初始化之前) # 配置模块使用标准logging(在logger.py初始化之前)
config_logger = logging.getLogger(__name__) config_logger = logging.getLogger(__name__)
# 数据库文件路径(绝对路径) # 数据库配置:支持PostgreSQL和SQLite
# 优先使用环境变量DATABASE_URL,否则使用SQLite
DB_FILE = DATA_DIR / "ai_story.db" DB_FILE = DATA_DIR / "ai_story.db"
DEFAULT_SQLITE_URL = f"sqlite+aiosqlite:///{str(DB_FILE.absolute()).replace(chr(92), '/')}"
# 生成数据库URL(在类外部生成,确保使用绝对路径) # 从环境变量获取数据库URL,如果未设置则使用SQLite
# 将Windows反斜杠转换为正斜杠,SQLite URL格式要求 DATABASE_URL = os.getenv("DATABASE_URL", DEFAULT_SQLITE_URL)
DATABASE_URL = f"sqlite+aiosqlite:///{str(DB_FILE.absolute()).replace(chr(92), '/')}"
config_logger.debug(f"数据库文件路径: {DB_FILE}") config_logger.debug(f"数据库类型: {'PostgreSQL' if 'postgresql' in DATABASE_URL else 'SQLite'}")
config_logger.debug(f"数据库URL: {DATABASE_URL}") config_logger.debug(f"数据库URL: {DATABASE_URL}")
class Settings(BaseSettings): class Settings(BaseSettings):
@@ -41,9 +44,31 @@ class Settings(BaseSettings):
# CORS配置 # CORS配置
cors_origins: list[str] = ["http://localhost:8000", "http://127.0.0.1:8000"] cors_origins: list[str] = ["http://localhost:8000", "http://127.0.0.1:8000"]
# 数据库配置 - 使用预先计算好的绝对路径URL # 数据库配置 - 支持PostgreSQL和SQLite
database_url: str = DATABASE_URL database_url: str = DATABASE_URL
# PostgreSQL连接池配置(优化后支持80-150并发用户)
database_pool_size: int = 30 # 核心连接池大小(从20提升到30
database_max_overflow: int = 20 # 最大溢出连接数(从10提升到20
database_pool_timeout: int = 60 # 连接池超时秒数(从30提升到60
database_pool_recycle: int = 1800 # 连接回收时间秒数(从3600降低到1800,30分钟)
database_pool_pre_ping: bool = True # 连接前ping检测,确保连接有效
database_pool_use_lifo: bool = True # 使用LIFO策略提高连接复用率
# 会话监控配置
database_session_max_active: int = 50 # 活跃会话警告阈值(从100降低到50)
database_session_leak_threshold: int = 100 # 会话泄漏严重告警阈值
# SQLite优化配置
sqlite_cache_size_mb: int = 128 # SQLite缓存大小MB(从64提升到128
sqlite_mmap_size_mb: int = 256 # 内存映射I/O大小MB
sqlite_wal_autocheckpoint: int = 1000 # WAL自动检查点间隔
# 数据库监控配置
database_enable_slow_query_log: bool = True # 启用慢查询日志
database_slow_query_threshold: float = 1.0 # 慢查询阈值(秒)
database_enable_metrics: bool = True # 启用性能指标收集
# AI服务配置 # AI服务配置
openai_api_key: Optional[str] = None openai_api_key: Optional[str] = None
openai_base_url: Optional[str] = None openai_base_url: Optional[str] = None
+219 -7
View File
@@ -45,12 +45,59 @@ _session_stats = {
async def get_engine(user_id: str): async def get_engine(user_id: str):
"""获取或创建用户专属的数据库引擎(线程安全) """获取或创建用户专属的数据库引擎(线程安全)
支持PostgreSQL和SQLite两种数据库:
- PostgreSQL: 所有用户共享一个数据库,通过user_id字段隔离数据
- SQLite: 每个用户一个独立的数据库文件
Args: Args:
user_id: 用户ID user_id: 用户ID
Returns: Returns:
用户专属的异步引擎 用户专属的异步引擎
""" """
# PostgreSQL模式:所有用户共享同一个引擎
if "postgresql" in settings.database_url:
cache_key = "shared_postgres"
if cache_key in _engine_cache:
return _engine_cache[cache_key]
async with _cache_lock:
if cache_key not in _engine_cache:
# 优化后的PostgreSQL连接配置
connect_args = {
"server_settings": {
"application_name": settings.app_name,
"jit": "off", # 关闭JIT以提高短查询性能
},
"command_timeout": 60, # 命令超时60秒
"statement_cache_size": 500, # 启用语句缓存,提升重复查询性能
}
engine = create_async_engine(
settings.database_url,
echo=False, # 生产环境关闭SQL日志
future=True,
pool_size=settings.database_pool_size, # 核心连接数:30
max_overflow=settings.database_max_overflow, # 溢出连接数:20
pool_timeout=settings.database_pool_timeout, # 连接超时:60秒
pool_pre_ping=settings.database_pool_pre_ping, # 连接前检测
pool_recycle=settings.database_pool_recycle, # 连接回收:1800秒
pool_use_lifo=settings.database_pool_use_lifo, # LIFO策略提高复用
connect_args=connect_args
)
_engine_cache[cache_key] = engine
logger.info(
f"✅ PostgreSQL引擎已创建(优化配置)\n"
f" ├─ 连接池: {settings.database_pool_size} 核心 + {settings.database_max_overflow} 溢出 = {settings.database_pool_size + settings.database_max_overflow} 总连接\n"
f" ├─ 超时: {settings.database_pool_timeout}\n"
f" ├─ 回收: {settings.database_pool_recycle}\n"
f" ├─ 策略: LIFO(提高复用率)\n"
f" └─ 预估并发: 80-150用户"
)
return _engine_cache[cache_key]
# SQLite模式:每个用户独立的数据库文件
if user_id in _engine_cache: if user_id in _engine_cache:
return _engine_cache[user_id] return _engine_cache[user_id]
@@ -76,18 +123,30 @@ async def get_engine(user_id: str):
) )
try: try:
# 应用优化后的SQLite配置
cache_size = -1024 * settings.sqlite_cache_size_mb # 负数表示KB单位
mmap_size = settings.sqlite_mmap_size_mb * 1024 * 1024 # 转换为字节
async with engine.begin() as conn: async with engine.begin() as conn:
await conn.execute(text("PRAGMA journal_mode=WAL")) await conn.execute(text("PRAGMA journal_mode=WAL"))
await conn.execute(text("PRAGMA synchronous=NORMAL")) await conn.execute(text("PRAGMA synchronous=NORMAL"))
await conn.execute(text("PRAGMA cache_size=-64000")) await conn.execute(text(f"PRAGMA cache_size={cache_size}")) # 128MB缓存
await conn.execute(text(f"PRAGMA mmap_size={mmap_size}")) # 256MB内存映射
await conn.execute(text("PRAGMA temp_store=MEMORY")) await conn.execute(text("PRAGMA temp_store=MEMORY"))
await conn.execute(text("PRAGMA busy_timeout=5000")) await conn.execute(text("PRAGMA busy_timeout=5000"))
await conn.execute(text(f"PRAGMA wal_autocheckpoint={settings.sqlite_wal_autocheckpoint}"))
logger.info(f"✅ 用户 {user_id} 的数据库已优化(WAL模式 + 64MB缓存)") logger.info(
f"✅ 用户 {user_id} 的SQLite数据库已优化\n"
f" ├─ WAL模式\n"
f" ├─ 缓存: {settings.sqlite_cache_size_mb}MB\n"
f" ├─ 内存映射: {settings.sqlite_mmap_size_mb}MB\n"
f" └─ 预估并发: 15-20写入用户"
)
except Exception as e: except Exception as e:
logger.warning(f"⚠️ 用户 {user_id} 数据库优化失败: {str(e)}") logger.warning(f"⚠️ 用户 {user_id} SQLite数据库优化失败: {str(e)}")
_engine_cache[user_id] = engine _engine_cache[user_id] = engine
logger.info(f"为用户 {user_id} 创建数据库引擎") logger.info(f"为用户 {user_id} 创建SQLite数据库引擎")
return _engine_cache[user_id] return _engine_cache[user_id]
@@ -157,8 +216,11 @@ async def get_db(request: Request):
logger.debug(f"📊 会话关闭 [User:{user_id}][ID:{session_id}] - 活跃:{_session_stats['active']}, 总创建:{_session_stats['created']}, 总关闭:{_session_stats['closed']}, 错误:{_session_stats['errors']}") logger.debug(f"📊 会话关闭 [User:{user_id}][ID:{session_id}] - 活跃:{_session_stats['active']}, 总创建:{_session_stats['created']}, 总关闭:{_session_stats['closed']}, 错误:{_session_stats['errors']}")
if _session_stats["active"] > 100: # 使用优化后的会话监控阈值
logger.warning(f"🚨 活跃会话数过多: {_session_stats['active']},可能存在连接泄漏!") if _session_stats["active"] > settings.database_session_leak_threshold:
logger.error(f"🚨 严重告警:活跃会话数 {_session_stats['active']} 超过泄漏阈值 {settings.database_session_leak_threshold}")
elif _session_stats["active"] > settings.database_session_max_active:
logger.warning(f"⚠️ 警告:活跃会话数 {_session_stats['active']} 超过警告阈值 {settings.database_session_max_active},可能存在连接泄漏!")
elif _session_stats["active"] < 0: elif _session_stats["active"] < 0:
logger.error(f"🚨 活跃会话数异常: {_session_stats['active']},统计可能不准确!") logger.error(f"🚨 活跃会话数异常: {_session_stats['active']},统计可能不准确!")
@@ -324,4 +386,154 @@ async def close_db():
logger.info("所有数据库连接已关闭") logger.info("所有数据库连接已关闭")
except Exception as e: except Exception as e:
logger.error(f"关闭数据库连接失败: {str(e)}", exc_info=True) logger.error(f"关闭数据库连接失败: {str(e)}", exc_info=True)
raise raise
async def get_database_stats():
"""获取数据库连接和会话统计信息
Returns:
dict: 包含数据库统计信息的字典
"""
from app.config import settings
stats = {
"session_stats": {
"created": _session_stats["created"],
"closed": _session_stats["closed"],
"active": _session_stats["active"],
"errors": _session_stats["errors"],
"generator_exits": _session_stats["generator_exits"],
"last_check": _session_stats["last_check"],
},
"engine_cache": {
"total_engines": len(_engine_cache),
"engine_keys": list(_engine_cache.keys()),
},
"config": {
"database_type": "PostgreSQL" if "postgresql" in settings.database_url else "SQLite",
"pool_size": settings.database_pool_size,
"max_overflow": settings.database_max_overflow,
"total_connections": settings.database_pool_size + settings.database_max_overflow,
"pool_timeout": settings.database_pool_timeout,
"session_max_active_threshold": settings.database_session_max_active,
"session_leak_threshold": settings.database_session_leak_threshold,
},
"health": {
"status": "healthy",
"warnings": [],
"errors": [],
}
}
# 健康检查
if _session_stats["active"] > settings.database_session_leak_threshold:
stats["health"]["status"] = "critical"
stats["health"]["errors"].append(
f"活跃会话数 {_session_stats['active']} 超过泄漏阈值 {settings.database_session_leak_threshold}"
)
elif _session_stats["active"] > settings.database_session_max_active:
stats["health"]["status"] = "warning"
stats["health"]["warnings"].append(
f"活跃会话数 {_session_stats['active']} 超过警告阈值 {settings.database_session_max_active}"
)
if _session_stats["active"] < 0:
stats["health"]["status"] = "error"
stats["health"]["errors"].append(f"活跃会话数异常: {_session_stats['active']}")
error_rate = (_session_stats["errors"] / max(_session_stats["created"], 1)) * 100
if error_rate > 5:
stats["health"]["status"] = "warning"
stats["health"]["warnings"].append(f"会话错误率过高: {error_rate:.2f}%")
stats["health"]["error_rate"] = f"{error_rate:.2f}%"
return stats
async def check_database_health(user_id: str = None) -> dict:
"""检查数据库连接健康状态
Args:
user_id: 可选的用户ID,如果提供则检查特定用户的数据库
Returns:
dict: 健康检查结果
"""
result = {
"healthy": True,
"checks": {},
"timestamp": datetime.now().isoformat()
}
try:
# 检查引擎是否存在
if user_id:
engine = await get_engine(user_id)
cache_key = user_id
else:
if "postgresql" in settings.database_url:
cache_key = "shared_postgres"
if cache_key not in _engine_cache:
result["checks"]["engine"] = {"status": "not_initialized", "healthy": True}
return result
engine = _engine_cache[cache_key]
else:
result["checks"]["engine"] = {"status": "skipped", "message": "需要提供user_id检查SQLite"}
return result
# 测试数据库连接
AsyncSessionLocal = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False
)
async with AsyncSessionLocal() as session:
# 执行简单查询测试连接
await session.execute(text("SELECT 1"))
result["checks"]["connection"] = {"status": "ok", "healthy": True}
# 检查连接池状态(仅PostgreSQL)
if hasattr(engine.pool, 'size'):
pool_status = {
"size": engine.pool.size(),
"checked_in": engine.pool.checkedin(),
"checked_out": engine.pool.checkedout(),
"overflow": engine.pool.overflow(),
"healthy": True
}
# 连接池健康检查
if engine.pool.overflow() >= settings.database_max_overflow:
pool_status["healthy"] = False
pool_status["warning"] = "连接池溢出已满"
result["healthy"] = False
result["checks"]["pool"] = pool_status
except Exception as e:
result["healthy"] = False
result["checks"]["error"] = {
"status": "error",
"message": str(e),
"healthy": False
}
logger.error(f"数据库健康检查失败: {str(e)}", exc_info=True)
return result
async def reset_session_stats():
"""重置会话统计信息(用于测试或维护)"""
global _session_stats
_session_stats = {
"created": 0,
"closed": 0,
"active": 0,
"errors": 0,
"generator_exits": 0,
"last_check": datetime.now().isoformat()
}
logger.info("✅ 会话统计信息已重置")
return _session_stats
-1
View File
@@ -28,7 +28,6 @@ logger = get_logger(__name__)
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
"""应用生命周期管理""" """应用生命周期管理"""
logger.info("应用启动,等待用户登录...") logger.info("应用启动,等待用户登录...")
logger.info("💡 MCP插件采用延迟加载策略,将在用户首次使用时自动加载")
yield yield
+15 -2
View File
@@ -267,13 +267,26 @@ class HTTPMCPClient:
start_time = time.time() start_time = time.time()
try: try:
# 尝试连接并列举工具 # 尝试连接并列举工具(直接调用SDK,避免重复日志)
await self._ensure_connected() await self._ensure_connected()
tools = await self.list_tools()
result = await self._session.list_tools()
# 转换为字典格式
tools = []
for tool in result.tools:
tool_dict = {
"name": tool.name,
"description": tool.description or "",
"inputSchema": tool.inputSchema
}
tools.append(tool_dict)
end_time = time.time() end_time = time.time()
response_time = round((end_time - start_time) * 1000, 2) response_time = round((end_time - start_time) * 1000, 2)
logger.info(f"✅ 连接测试成功,获取到 {len(tools)} 个工具")
return { return {
"success": True, "success": True,
"message": "连接测试成功", "message": "连接测试成功",
+2 -2
View File
@@ -14,8 +14,8 @@ class Character(Base):
# 基本信息 # 基本信息
name = Column(String(100), nullable=False, comment="角色/组织名称") name = Column(String(100), nullable=False, comment="角色/组织名称")
age = Column(String(20), comment="年龄") age = Column(String(50), comment="年龄")
gender = Column(String(20), comment="性别") gender = Column(String(50), comment="性别")
is_organization = Column(Boolean, default=False, comment="是否为组织") is_organization = Column(Boolean, default=False, comment="是否为组织")
# 角色类型:protagonist(主角)/supporting(配角)/antagonist(反派) # 角色类型:protagonist(主角)/supporting(配角)/antagonist(反派)
+2 -2
View File
@@ -9,7 +9,7 @@ class StoryMemory(Base):
"""故事记忆表 - 存储结构化的故事片段和元数据""" """故事记忆表 - 存储结构化的故事片段和元数据"""
__tablename__ = "story_memories" __tablename__ = "story_memories"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) id = Column(String(100), primary_key=True, default=lambda: str(uuid.uuid4()))
project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False, index=True) project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False, index=True)
chapter_id = Column(String(36), ForeignKey("chapters.id", ondelete="CASCADE"), nullable=True, index=True) chapter_id = Column(String(36), ForeignKey("chapters.id", ondelete="CASCADE"), nullable=True, index=True)
@@ -45,7 +45,7 @@ class StoryMemory(Base):
# 伏笔相关字段 # 伏笔相关字段
is_foreshadow = Column(Integer, default=0, comment="伏笔状态: 0=普通记忆, 1=已埋下伏笔, 2=伏笔已回收") is_foreshadow = Column(Integer, default=0, comment="伏笔状态: 0=普通记忆, 1=已埋下伏笔, 2=伏笔已回收")
foreshadow_resolved_at = Column(String(36), ForeignKey("chapters.id", ondelete="SET NULL"), comment="伏笔回收的章节ID") foreshadow_resolved_at = Column(String(100), ForeignKey("chapters.id", ondelete="SET NULL"), comment="伏笔回收的章节ID")
foreshadow_strength = Column(Float, comment="伏笔强度 0.0-1.0") foreshadow_strength = Column(Float, comment="伏笔强度 0.0-1.0")
# 向量数据库关联 # 向量数据库关联
+1
View File
@@ -10,6 +10,7 @@ class Project(Base):
__tablename__ = "projects" __tablename__ = "projects"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
user_id = Column(String(36), nullable=False, index=True, comment="用户ID")
title = Column(String(200), nullable=False, comment="项目标题") title = Column(String(200), nullable=False, comment="项目标题")
description = Column(Text, comment="项目简介") description = Column(Text, comment="项目简介")
theme = Column(Text, comment="主题") theme = Column(Text, comment="主题")
+1 -1
View File
@@ -75,7 +75,7 @@ class Organization(Base):
# 组织特色 # 组织特色
motto = Column(String(200), comment="宗旨/口号") motto = Column(String(200), comment="宗旨/口号")
color = Column(String(20), comment="代表颜色") color = Column(String(100), comment="代表颜色")
created_at = Column(DateTime, server_default=func.now(), comment="创建时间") created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间") updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
+6 -2
View File
@@ -82,7 +82,9 @@ class AIService:
self.openai_http_client = None self.openai_http_client = None
self.openai_api_key = None self.openai_api_key = None
self.openai_base_url = None self.openai_base_url = None
logger.warning("OpenAI API key未配置") # 只有当用户明确选择OpenAI作为提供商时才警告
if self.api_provider == "openai":
logger.warning("⚠️ OpenAI API key未配置,但被设置为当前AI提供商")
# 初始化Anthropic客户端 # 初始化Anthropic客户端
anthropic_key = api_key if api_provider == "anthropic" else app_settings.anthropic_api_key anthropic_key = api_key if api_provider == "anthropic" else app_settings.anthropic_api_key
@@ -118,7 +120,9 @@ class AIService:
self.anthropic_client = None self.anthropic_client = None
else: else:
self.anthropic_client = None self.anthropic_client = None
logger.warning("Anthropic API key未配置") # 只有当用户明确选择Anthropic作为提供商时才警告
if self.api_provider == "anthropic":
logger.warning("⚠️ Anthropic API key未配置,但被设置为当前AI提供商")
async def generate_text( async def generate_text(
self, self,
+1 -1
View File
@@ -87,7 +87,7 @@ class MemoryService:
'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2', 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
cache_folder=model_cache_dir, cache_folder=model_cache_dir,
device='cpu', # 明确指定使用CPU device='cpu', # 明确指定使用CPU
trust_remote_code=False # 安全起见 trust_remote_code=False, # 安全起见
) )
logger.info("✅ Embedding模型加载成功 (paraphrase-multilingual-MiniLM-L12-v2)") logger.info("✅ Embedding模型加载成功 (paraphrase-multilingual-MiniLM-L12-v2)")
except Exception as e: except Exception as e:
+7 -5
View File
@@ -5,7 +5,9 @@ python-multipart==0.0.20
# 数据库 # 数据库
sqlalchemy==2.0.25 sqlalchemy==2.0.25
aiosqlite==0.19.0 aiosqlite==0.19.0 # SQLite支持(保留用于开发环境)
asyncpg==0.29.0 # PostgreSQL异步驱动(生产环境)
psycopg2-binary==2.9.9 # PostgreSQL同步驱动(备用)
# 数据验证 # 数据验证
pydantic==2.12.4 pydantic==2.12.4
@@ -29,8 +31,8 @@ numpy==1.26.4
chromadb==1.3.2 chromadb==1.3.2
# Transformers锁定兼容版本 # Transformers更新到最新稳定版本以修复 FutureWarning
transformers==4.35.2 transformers==4.57.1
# Sentence Transformers基于PyTorch的文本embedding # Sentence Transformers更新到最新稳定版本以修复 FutureWarning
sentence-transformers==2.3.1 sentence-transformers==5.1.2
+30
View File
@@ -0,0 +1,30 @@
-- PostgreSQL 初始化脚本
-- 此脚本会在PostgreSQL容器首次启动时自动执行
-- 确保使用UTF8编码
SET client_encoding = 'UTF8';
-- 创建必要的扩展
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
CREATE EXTENSION IF NOT EXISTS "pg_trgm";
-- 设置时区
SET timezone = 'Asia/Shanghai';
-- 优化配置(这些设置会在容器启动后生效)
-- 注意:部分配置已在docker-compose.yml的command中设置
-- 创建索引优化查询性能(表会由SQLAlchemy自动创建)
-- 这里只是预留空间,实际索引会在应用启动时创建
-- 输出初始化信息
DO $$
BEGIN
RAISE NOTICE '==================================================';
RAISE NOTICE 'MuMuAINovel PostgreSQL 数据库初始化完成';
RAISE NOTICE '数据库名称: mumuai_novel';
RAISE NOTICE '字符编码: UTF8';
RAISE NOTICE '时区设置: Asia/Shanghai';
RAISE NOTICE '扩展已安装: uuid-ossp, pg_trgm';
RAISE NOTICE '==================================================';
END $$;
@@ -0,0 +1,816 @@
#!/usr/bin/env python3
"""
SQLite to PostgreSQL 数据迁移脚本
使用方法:
python backend/scripts/migrate_sqlite_to_postgres.py
前置条件:
1. PostgreSQL数据库已创建
2. .env文件中DATABASE_URL已配置为PostgreSQL
3. SQLite数据文件存在于 backend/data/ 目录
"""
import asyncio
import sys
from pathlib import Path
from typing import List, Dict, Any
import logging
from datetime import datetime
# 添加项目根目录到Python路径
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy import create_engine, text, select
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from app.database import Base
from app.models import (
Project, Outline, Character, Chapter, GenerationHistory,
Settings, WritingStyle, ProjectDefaultStyle,
RelationshipType, CharacterRelationship, Organization, OrganizationMember,
StoryMemory, PlotAnalysis, AnalysisTask, BatchGenerationTask,
MCPPlugin
)
from app.config import settings
# 创建日志目录
log_dir = Path(__file__).parent.parent / "logs"
log_dir.mkdir(exist_ok=True)
# 生成日志文件名(带时间戳)
log_filename = log_dir / f"migration_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
# 设置日志 - 同时输出到控制台和文件
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(), # 控制台输出
logging.FileHandler(log_filename, encoding='utf-8') # 文件输出
]
)
logger = logging.getLogger(__name__)
logger.info(f"📝 日志文件: {log_filename}")
class SQLiteToPostgresMigrator:
"""SQLite到PostgreSQL的数据迁移器"""
def __init__(self, sqlite_dir: Path, target_user_id: str):
"""
初始化迁移器
Args:
sqlite_dir: SQLite数据库文件目录
target_user_id: 目标用户ID(迁移后的数据归属)
"""
self.sqlite_dir = sqlite_dir
self.target_user_id = target_user_id
self.sqlite_files = list(sqlite_dir.glob("ai_story_user_*.db"))
# PostgreSQL连接
if "postgresql" not in settings.database_url:
raise ValueError("DATABASE_URL必须配置为PostgreSQL")
self.pg_engine = create_async_engine(
settings.database_url,
echo=False,
pool_pre_ping=True
)
self.pg_session_maker = async_sessionmaker(
self.pg_engine,
class_=AsyncSession,
expire_on_commit=False
)
async def migrate_all(self):
"""迁移所有SQLite数据库"""
if not self.sqlite_files:
logger.warning(f"未找到SQLite数据库文件: {self.sqlite_dir}")
return
logger.info(f"找到 {len(self.sqlite_files)} 个SQLite数据库文件")
# 创建PostgreSQL表结构
await self._create_tables()
# 初始化关系类型数据
await self._init_relationship_types()
# 逐个迁移
for sqlite_file in self.sqlite_files:
await self._migrate_single_db(sqlite_file)
logger.info("✅ 所有数据迁移完成")
async def _create_tables(self):
"""创建PostgreSQL表结构"""
logger.info("创建PostgreSQL表结构...")
async with self.pg_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
logger.info("✅ 表结构创建完成")
async def _init_relationship_types(self):
"""初始化关系类型数据"""
logger.info("初始化关系类型数据...")
# 预置关系类型数据
relationship_types = [
# 家族关系
{"name": "父亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👨"},
{"name": "母亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👩"},
{"name": "兄弟", "category": "family", "reverse_name": "兄弟", "intimacy_range": "high", "icon": "👬"},
{"name": "姐妹", "category": "family", "reverse_name": "姐妹", "intimacy_range": "high", "icon": "👭"},
{"name": "子女", "category": "family", "reverse_name": "父母", "intimacy_range": "high", "icon": "👶"},
{"name": "配偶", "category": "family", "reverse_name": "配偶", "intimacy_range": "high", "icon": "💑"},
{"name": "恋人", "category": "family", "reverse_name": "恋人", "intimacy_range": "high", "icon": "💕"},
# 社交关系
{"name": "师父", "category": "social", "reverse_name": "徒弟", "intimacy_range": "high", "icon": "🎓"},
{"name": "徒弟", "category": "social", "reverse_name": "师父", "intimacy_range": "high", "icon": "📚"},
{"name": "朋友", "category": "social", "reverse_name": "朋友", "intimacy_range": "medium", "icon": "🤝"},
{"name": "同学", "category": "social", "reverse_name": "同学", "intimacy_range": "medium", "icon": "🎒"},
{"name": "邻居", "category": "social", "reverse_name": "邻居", "intimacy_range": "low", "icon": "🏘️"},
{"name": "知己", "category": "social", "reverse_name": "知己", "intimacy_range": "high", "icon": "💙"},
# 职业关系
{"name": "上司", "category": "professional", "reverse_name": "下属", "intimacy_range": "low", "icon": "👔"},
{"name": "下属", "category": "professional", "reverse_name": "上司", "intimacy_range": "low", "icon": "💼"},
{"name": "同事", "category": "professional", "reverse_name": "同事", "intimacy_range": "medium", "icon": "🤵"},
{"name": "合作伙伴", "category": "professional", "reverse_name": "合作伙伴", "intimacy_range": "medium", "icon": "🤜🤛"},
# 敌对关系
{"name": "敌人", "category": "hostile", "reverse_name": "敌人", "intimacy_range": "low", "icon": "⚔️"},
{"name": "仇人", "category": "hostile", "reverse_name": "仇人", "intimacy_range": "low", "icon": "💢"},
{"name": "竞争对手", "category": "hostile", "reverse_name": "竞争对手", "intimacy_range": "low", "icon": "🎯"},
{"name": "宿敌", "category": "hostile", "reverse_name": "宿敌", "intimacy_range": "low", "icon": ""},
]
try:
async with self.pg_session_maker() as session:
# 检查是否已经有数据
result = await session.execute(select(RelationshipType))
existing = result.scalars().first()
if existing:
logger.info("关系类型数据已存在,跳过初始化")
return
# 插入预置数据
logger.info("开始插入关系类型数据...")
for rt_data in relationship_types:
relationship_type = RelationshipType(**rt_data)
session.add(relationship_type)
await session.commit()
logger.info(f"✅ 成功插入 {len(relationship_types)} 条关系类型数据")
except Exception as e:
logger.error(f"初始化关系类型数据失败: {str(e)}", exc_info=True)
# 不抛出异常,继续迁移流程
logger.warning("关系类型初始化失败,将跳过有外键依赖的记录")
async def _migrate_single_db(self, sqlite_file: Path):
"""迁移单个SQLite数据库"""
# 从文件名提取user_id
filename = sqlite_file.stem # ai_story_user_xxx
if filename.startswith("ai_story_user_"):
user_id = filename.replace("ai_story_user_", "")
else:
user_id = self.target_user_id
logger.info(f"\n{'='*60}")
logger.info(f"开始迁移: {sqlite_file.name} -> user_id: {user_id}")
logger.info(f"{'='*60}")
# 创建SQLite连接
sqlite_url = f"sqlite+aiosqlite:///{sqlite_file.absolute()}"
sqlite_engine = create_async_engine(sqlite_url, echo=False)
sqlite_session_maker = async_sessionmaker(
sqlite_engine,
class_=AsyncSession,
expire_on_commit=False
)
try:
# 迁移各个表
async with sqlite_session_maker() as sqlite_session:
async with self.pg_session_maker() as pg_session:
# 按照依赖顺序迁移
await self._migrate_table(
sqlite_session, pg_session, user_id, Settings, "设置"
)
await self._migrate_table(
sqlite_session, pg_session, user_id, Project, "项目"
)
await self._migrate_table(
sqlite_session, pg_session, user_id, Character, "角色"
)
await self._migrate_table(
sqlite_session, pg_session, user_id, Outline, "大纲"
)
await self._migrate_table(
sqlite_session, pg_session, user_id, Chapter, "章节"
)
await self._migrate_table(
sqlite_session, pg_session, user_id, CharacterRelationship, "角色关系"
)
await self._migrate_table(
sqlite_session, pg_session, user_id, Organization, "组织"
)
await self._migrate_table(
sqlite_session, pg_session, user_id, OrganizationMember, "组织成员"
)
await self._migrate_table(
sqlite_session, pg_session, user_id, GenerationHistory, "生成历史"
)
await self._migrate_table(
sqlite_session, pg_session, user_id, WritingStyle, "写作风格"
)
await self._migrate_table(
sqlite_session, pg_session, user_id, ProjectDefaultStyle, "项目默认风格"
)
await self._migrate_table(
sqlite_session, pg_session, user_id, StoryMemory, "记忆"
)
await self._migrate_table(
sqlite_session, pg_session, user_id, PlotAnalysis, "剧情分析"
)
await self._migrate_table(
sqlite_session, pg_session, user_id, AnalysisTask, "分析任务"
)
await self._migrate_table(
sqlite_session, pg_session, user_id, BatchGenerationTask, "批量生成任务"
)
await self._migrate_table(
sqlite_session, pg_session, user_id, MCPPlugin, "MCP插件"
)
await pg_session.commit()
logger.info(f"{sqlite_file.name} 迁移完成")
except Exception as e:
logger.error(f"❌ 迁移失败: {e}", exc_info=True)
finally:
await sqlite_engine.dispose()
async def _migrate_table(
self,
sqlite_session: AsyncSession,
pg_session: AsyncSession,
user_id: str,
model_class,
table_name: str
):
"""迁移单个表的数据"""
try:
# 获取SQLite表中实际存在的列
sqlite_table = model_class.__table__
sqlite_conn = await sqlite_session.connection()
# 查询SQLite表结构
inspect_result = await sqlite_conn.execute(
text(f"PRAGMA table_info({sqlite_table.name})")
)
sqlite_columns = {row[1] for row in inspect_result.fetchall()} # row[1]是列名
# 构建只包含SQLite中存在的列的查询
available_columns = [
c for c in model_class.__table__.columns
if c.name in sqlite_columns
]
if not available_columns:
logger.warning(f" ⚠️ {table_name}: 表结构不匹配,跳过")
return
# 从SQLite读取数据(只查询存在的列)
result = await sqlite_session.execute(
select(*available_columns)
)
records = result.all()
if not records:
logger.info(f" - {table_name}: 无数据")
return
# 为每条记录创建字典并添加user_id
migrated_count = 0
skipped_count = 0
for record in records:
# 从查询结果构建字典
record_dict = {}
for i, col in enumerate(available_columns):
record_dict[col.name] = record[i]
# 添加user_id(如果PostgreSQL模型有这个字段但SQLite没有)
if hasattr(model_class, 'user_id') and 'user_id' not in record_dict:
record_dict['user_id'] = user_id
# 验证字段长度(防止超长字段导致插入失败)
if not self._validate_field_lengths(model_class, record_dict, table_name):
skipped_count += 1
record_id = record_dict.get('id', 'unknown')
logger.warning(f" ⚠️ [{table_name}] 跳过超长字段记录 ID={record_id}")
continue
# 验证外键引用(针对有外键的表)
validation_result = await self._validate_foreign_keys(pg_session, model_class, record_dict)
if not validation_result:
skipped_count += 1
record_id = record_dict.get('id', 'unknown')
logger.warning(f" ⚠️ [{table_name}] 跳过无效外键记录 ID={record_id}")
# 输出记录详情以便调试
if model_class.__tablename__ == 'story_memories':
logger.warning(f" 记忆详情: project_id={record_dict.get('project_id')}, "
f"chapter_id={record_dict.get('chapter_id')}, "
f"type={record_dict.get('memory_type')}")
elif model_class.__tablename__ == 'character_relationships':
logger.warning(f" 关系详情: project_id={record_dict.get('project_id')}, "
f"from={record_dict.get('character_from_id')}, "
f"to={record_dict.get('character_to_id')}, "
f"type_id={record_dict.get('relationship_type_id')}")
elif model_class.__tablename__ == 'organizations':
logger.warning(f" 组织详情: project_id={record_dict.get('project_id')}, "
f"character_id={record_dict.get('character_id')}")
elif model_class.__tablename__ == 'organization_members':
logger.warning(f" 成员详情: org_id={record_dict.get('organization_id')}, "
f"character_id={record_dict.get('character_id')}")
elif model_class.__tablename__ == 'writing_styles':
logger.warning(f" 写作风格详情: project_id={record_dict.get('project_id')}, "
f"name={record_dict.get('name')}, "
f"style_type={record_dict.get('style_type')}")
elif model_class.__tablename__ == 'characters':
logger.warning(f" 角色详情: project_id={record_dict.get('project_id')}, "
f"name={record_dict.get('name')}, "
f"is_organization={record_dict.get('is_organization')}")
elif model_class.__tablename__ == 'outlines':
logger.warning(f" 大纲详情: project_id={record_dict.get('project_id')}, "
f"title={record_dict.get('title')}")
elif model_class.__tablename__ == 'chapters':
logger.warning(f" 章节详情: project_id={record_dict.get('project_id')}, "
f"title={record_dict.get('title')}, "
f"chapter_number={record_dict.get('chapter_number')}")
elif model_class.__tablename__ == 'generation_history':
logger.warning(f" 生成历史详情: project_id={record_dict.get('project_id')}, "
f"chapter_id={record_dict.get('chapter_id')}, "
f"model={record_dict.get('model')}")
elif model_class.__tablename__ == 'plot_analysis':
logger.warning(f" 剧情分析详情: project_id={record_dict.get('project_id')}, "
f"chapter_id={record_dict.get('chapter_id')}, "
f"plot_stage={record_dict.get('plot_stage')}")
elif model_class.__tablename__ == 'analysis_tasks':
logger.warning(f" 分析任务详情: chapter_id={record_dict.get('chapter_id')}, "
f"project_id={record_dict.get('project_id')}, "
f"status={record_dict.get('status')}")
elif model_class.__tablename__ == 'batch_generation_tasks':
logger.warning(f" 批量生成任务详情: project_id={record_dict.get('project_id')}, "
f"status={record_dict.get('status')}, "
f"completed={record_dict.get('completed_chapters')}/{record_dict.get('total_chapters')}")
elif model_class.__tablename__ == 'project_default_styles':
logger.warning(f" 项目默认风格详情: project_id={record_dict.get('project_id')}, "
f"style_id={record_dict.get('style_id')}")
continue
# 检查记录是否已存在(避免主键冲突)
record_id = record_dict.get('id')
if record_id and await self._record_exists(pg_session, model_class, record_id):
skipped_count += 1
logger.debug(f" 跳过已存在的记录: {record_id}")
continue
# 创建新记录
try:
new_record = model_class(**record_dict)
pg_session.add(new_record)
migrated_count += 1
except Exception as e:
logger.warning(f" ⚠️ 跳过无效记录: {str(e)[:100]}")
skipped_count += 1
continue
await pg_session.flush()
if skipped_count > 0:
logger.info(f"{table_name}: {migrated_count} 条记录(跳过 {skipped_count} 条无效记录)")
else:
logger.info(f"{table_name}: {migrated_count} 条记录")
except Exception as e:
logger.error(f"{table_name} 迁移失败: {e}")
raise
async def _record_exists(
self,
pg_session: AsyncSession,
model_class,
record_id: Any
) -> bool:
"""
检查记录是否已存在
Args:
pg_session: PostgreSQL会话
model_class: 模型类
record_id: 记录ID
Returns:
bool: 记录是否存在
"""
try:
# 获取主键列
pk_column = list(model_class.__table__.primary_key.columns)[0]
result = await pg_session.execute(
select(pk_column).where(pk_column == record_id)
)
return result.scalar_one_or_none() is not None
except Exception:
return False
async def _validate_foreign_keys(
self,
pg_session: AsyncSession,
model_class,
record_dict: Dict[str, Any]
) -> bool:
"""
验证记录的外键是否有效
Args:
pg_session: PostgreSQL会话
model_class: 模型类
record_dict: 记录字典
Returns:
bool: 外键是否全部有效
"""
from app.models import Character, Project, Chapter
# 使用no_autoflush防止过早flush
with pg_session.no_autoflush:
# 针对StoryMemory表验证外键
if model_class.__tablename__ == 'story_memories':
# 验证project_id
project_id = record_dict.get('project_id')
if project_id:
result = await pg_session.execute(
select(Project.id).where(Project.id == project_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [记忆] 无效的project_id: {project_id}")
return False
# 验证chapter_id(可选)
chapter_id = record_dict.get('chapter_id')
if chapter_id:
result = await pg_session.execute(
select(Chapter.id).where(Chapter.id == chapter_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [记忆] 无效的chapter_id: {chapter_id}")
return False
# 针对CharacterRelationship表验证外键
elif model_class.__tablename__ == 'character_relationships':
# 验证project_id
project_id = record_dict.get('project_id')
if project_id:
result = await pg_session.execute(
select(Project.id).where(Project.id == project_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ 无效的project_id: {project_id}")
return False
# 验证character_from_id
char_from_id = record_dict.get('character_from_id')
if char_from_id:
result = await pg_session.execute(
select(Character.id).where(Character.id == char_from_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ 无效的character_from_id: {char_from_id}")
return False
# 验证character_to_id
char_to_id = record_dict.get('character_to_id')
if char_to_id:
result = await pg_session.execute(
select(Character.id).where(Character.id == char_to_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ 无效的character_to_id: {char_to_id}")
return False
# 验证relationship_type_id
rel_type_id = record_dict.get('relationship_type_id')
if rel_type_id:
result = await pg_session.execute(
select(RelationshipType.id).where(RelationshipType.id == rel_type_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ 无效的relationship_type_id: {rel_type_id}")
return False
# 针对Organization表验证外键
elif model_class.__tablename__ == 'organizations':
# 验证character_id
char_id = record_dict.get('character_id')
if char_id:
result = await pg_session.execute(
select(Character.id).where(Character.id == char_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [组织] 无效的character_id: {char_id}")
return False
# 针对OrganizationMember表验证外键
elif model_class.__tablename__ == 'organization_members':
from app.models import Organization
# 验证organization_id
org_id = record_dict.get('organization_id')
if org_id:
result = await pg_session.execute(
select(Organization.id).where(Organization.id == org_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ 无效的organization_id: {org_id}")
return False
# 验证character_id
char_id = record_dict.get('character_id')
if char_id:
result = await pg_session.execute(
select(Character.id).where(Character.id == char_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [组织成员] 无效的character_id: {char_id}")
return False
# 针对Character表验证外键
elif model_class.__tablename__ == 'characters':
# 验证project_id
project_id = record_dict.get('project_id')
if project_id:
result = await pg_session.execute(
select(Project.id).where(Project.id == project_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [角色] 无效的project_id: {project_id}")
return False
# 针对Outline表验证外键
elif model_class.__tablename__ == 'outlines':
# 验证project_id
project_id = record_dict.get('project_id')
if project_id:
result = await pg_session.execute(
select(Project.id).where(Project.id == project_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [大纲] 无效的project_id: {project_id}")
return False
# 针对Chapter表验证外键
elif model_class.__tablename__ == 'chapters':
# 验证project_id
project_id = record_dict.get('project_id')
if project_id:
result = await pg_session.execute(
select(Project.id).where(Project.id == project_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [章节] 无效的project_id: {project_id}")
return False
# 针对WritingStyle表验证外键
elif model_class.__tablename__ == 'writing_styles':
# 验证project_id(可选)
project_id = record_dict.get('project_id')
if project_id:
result = await pg_session.execute(
select(Project.id).where(Project.id == project_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [写作风格] 无效的project_id: {project_id}")
return False
# 针对GenerationHistory表验证外键
elif model_class.__tablename__ == 'generation_history':
# 验证project_id
project_id = record_dict.get('project_id')
if project_id:
result = await pg_session.execute(
select(Project.id).where(Project.id == project_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [生成历史] 无效的project_id: {project_id}")
return False
# 验证chapter_id(可选)
chapter_id = record_dict.get('chapter_id')
if chapter_id:
result = await pg_session.execute(
select(Chapter.id).where(Chapter.id == chapter_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [生成历史] 无效的chapter_id: {chapter_id}")
return False
# 针对PlotAnalysis表验证外键
elif model_class.__tablename__ == 'plot_analysis':
# 验证project_id(必需)
project_id = record_dict.get('project_id')
if project_id:
result = await pg_session.execute(
select(Project.id).where(Project.id == project_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [剧情分析] 无效的project_id: {project_id}")
return False
# 验证chapter_id(必需)
chapter_id = record_dict.get('chapter_id')
if chapter_id:
result = await pg_session.execute(
select(Chapter.id).where(Chapter.id == chapter_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [剧情分析] 无效的chapter_id: {chapter_id}")
return False
# 针对AnalysisTask表验证外键
elif model_class.__tablename__ == 'analysis_tasks':
# 验证chapter_id(必需)
chapter_id = record_dict.get('chapter_id')
if chapter_id:
result = await pg_session.execute(
select(Chapter.id).where(Chapter.id == chapter_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [分析任务] 无效的chapter_id: {chapter_id}")
return False
# 验证project_id
project_id = record_dict.get('project_id')
if project_id:
result = await pg_session.execute(
select(Project.id).where(Project.id == project_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [分析任务] 无效的project_id: {project_id}")
return False
# 针对BatchGenerationTask表验证外键
elif model_class.__tablename__ == 'batch_generation_tasks':
# 验证project_id(必需)
project_id = record_dict.get('project_id')
if project_id:
result = await pg_session.execute(
select(Project.id).where(Project.id == project_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [批量生成任务] 无效的project_id: {project_id}")
return False
# 针对ProjectDefaultStyle表验证外键
elif model_class.__tablename__ == 'project_default_styles':
from app.models import WritingStyle
# 验证project_id(必需)
project_id = record_dict.get('project_id')
if project_id:
result = await pg_session.execute(
select(Project.id).where(Project.id == project_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [项目默认风格] 无效的project_id: {project_id}")
return False
# 验证style_id(必需)
style_id = record_dict.get('style_id')
if style_id:
result = await pg_session.execute(
select(WritingStyle.id).where(WritingStyle.id == style_id)
)
if not result.scalar_one_or_none():
logger.warning(f" ❌ [项目默认风格] 无效的style_id: {style_id}")
return False
return True
def _validate_field_lengths(
self,
model_class,
record_dict: Dict[str, Any],
table_name: str
) -> bool:
"""
验证记录的字段长度是否符合模型定义
Args:
model_class: 模型类
record_dict: 记录字典
table_name: 表名(用于日志)
Returns:
bool: 字段长度是否全部有效
"""
from sqlalchemy import String
# 检查所有字符串类型字段
for column in model_class.__table__.columns:
# 只检查有长度限制的String类型字段
if isinstance(column.type, String) and column.type.length:
field_name = column.name
field_value = record_dict.get(field_name)
max_length = column.type.length
# 如果字段有值且超过最大长度
if field_value and isinstance(field_value, str) and len(field_value) > max_length:
logger.warning(
f" ❌ [{table_name}] 字段 '{field_name}' 超长: "
f"{len(field_value)} > {max_length} (截断了 {len(field_value) - max_length} 字符)"
)
# 对于敏感字段如API密钥,记录部分内容
if field_name in ['api_key', 'api_base_url']:
preview = field_value[:50] + "..." + field_value[-20:] if len(field_value) > 70 else field_value
logger.warning(f" 值预览: {preview}")
return False
return True
async def cleanup(self):
"""清理资源"""
await self.pg_engine.dispose()
async def main():
"""主函数"""
banner = """
╔══════════════════════════════════════════════════════════════╗
║ SQLite to PostgreSQL 数据迁移工具 ║
║ ║
║ 此工具将SQLite数据迁移到PostgreSQL ║
║ 请确保: ║
║ 1. PostgreSQL数据库已创建 ║
║ 2. .env中DATABASE_URL已配置为PostgreSQL ║
║ 3. SQLite数据文件存在 ║
╚══════════════════════════════════════════════════════════════╝
"""
print(banner)
logger.info(banner)
# 配置
sqlite_dir = Path(__file__).parent.parent / "data"
target_user_id = "migrated_user" # 默认用户ID
config_info = f"""
配置信息:
SQLite目录: {sqlite_dir}
PostgreSQL: {settings.database_url}
目标用户ID: {target_user_id}
日志文件: {log_filename}
"""
print(config_info)
logger.info(config_info)
# 确认
response = input("是否继续迁移? (yes/no): ")
if response.lower() not in ['yes', 'y']:
print("已取消迁移")
return
# 执行迁移
migrator = SQLiteToPostgresMigrator(sqlite_dir, target_user_id)
try:
await migrator.migrate_all()
success_msg = """
🎉 数据迁移成功完成!
下一步:
1. 测试应用功能
2. 验证数据完整性
3. 备份SQLite文件后可删除
详细日志已保存到: {}
""".format(log_filename)
print(success_msg)
logger.info(success_msg)
except Exception as e:
error_msg = f"\n❌ 迁移失败: {e}\n详细日志已保存到: {log_filename}"
print(error_msg)
logger.error("迁移过程出错", exc_info=True)
finally:
await migrator.cleanup()
logger.info(f"🔒 数据库连接已关闭,日志文件: {log_filename}")
if __name__ == "__main__":
asyncio.run(main())
+408
View File
@@ -0,0 +1,408 @@
#!/usr/bin/env python3
"""
PostgreSQL 数据库自动设置脚本
功能:
1. 自动连接到PostgreSQL服务器
2. 创建数据库和用户
3. 设置权限
4. 初始化表结构
使用方法:
python backend/scripts/setup_postgres.py
前置条件:
- PostgreSQL服务已安装并运行
- 知道PostgreSQL的超级用户密码通常是postgres用户
"""
import sys
import asyncio
from pathlib import Path
from getpass import getpass
import logging
# 添加项目根目录到Python路径
sys.path.insert(0, str(Path(__file__).parent.parent))
try:
import psycopg2
from psycopg2 import sql
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
except ImportError:
print("❌ 缺少psycopg2依赖,正在安装...")
import subprocess
subprocess.check_call([sys.executable, "-m", "pip", "install", "psycopg2-binary"])
import psycopg2
from psycopg2 import sql
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
from app.database import init_db
# 设置日志
logging.basicConfig(
level=logging.INFO,
format='%(message)s'
)
logger = logging.getLogger(__name__)
class PostgreSQLSetup:
"""PostgreSQL数据库自动设置"""
def __init__(
self,
host: str = "localhost",
port: int = 5432,
admin_user: str = "postgres",
admin_password: str = None,
db_name: str = "mumuai_novel",
db_user: str = "mumuai",
db_password: str = "123456"
):
"""
初始化设置参数
Args:
host: PostgreSQL主机地址
port: PostgreSQL端口
admin_user: 管理员用户名
admin_password: 管理员密码
db_name: 要创建的数据库名
db_user: 要创建的用户名
db_password: 用户密码
"""
self.host = host
self.port = port
self.admin_user = admin_user
self.admin_password = admin_password
self.db_name = db_name
self.db_user = db_user
self.db_password = db_password
self.conn = None
def connect_as_admin(self) -> bool:
"""连接到PostgreSQL(使用管理员权限)"""
try:
logger.info(f"🔌 连接到 PostgreSQL ({self.host}:{self.port})...")
self.conn = psycopg2.connect(
host=self.host,
port=self.port,
user=self.admin_user,
password=self.admin_password,
database="postgres" # 连接到默认数据库
)
self.conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
logger.info(f"✅ 已连接到 PostgreSQL")
return True
except psycopg2.OperationalError as e:
logger.error(f"❌ 连接失败: {e}")
logger.error("\n可能的原因:")
logger.error("1. PostgreSQL服务未启动")
logger.error("2. 管理员密码错误")
logger.error("3. 主机地址或端口错误")
logger.error("4. pg_hba.conf配置不允许连接")
return False
def database_exists(self) -> bool:
"""检查数据库是否存在"""
cursor = self.conn.cursor()
cursor.execute(
"SELECT 1 FROM pg_database WHERE datname = %s",
(self.db_name,)
)
exists = cursor.fetchone() is not None
cursor.close()
return exists
def user_exists(self) -> bool:
"""检查用户是否存在"""
cursor = self.conn.cursor()
cursor.execute(
"SELECT 1 FROM pg_user WHERE usename = %s",
(self.db_user,)
)
exists = cursor.fetchone() is not None
cursor.close()
return exists
def create_user(self) -> bool:
"""创建数据库用户"""
try:
if self.user_exists():
logger.info(f"️ 用户 '{self.db_user}' 已存在")
# 询问是否重置密码
response = input(f"是否重置用户 '{self.db_user}' 的密码? (yes/no): ")
if response.lower() in ['yes', 'y']:
cursor = self.conn.cursor()
cursor.execute(
sql.SQL("ALTER USER {} WITH PASSWORD %s").format(
sql.Identifier(self.db_user)
),
(self.db_password,)
)
cursor.close()
logger.info(f"✅ 用户密码已更新")
return True
logger.info(f"👤 创建用户 '{self.db_user}'...")
cursor = self.conn.cursor()
cursor.execute(
sql.SQL("CREATE USER {} WITH PASSWORD %s").format(
sql.Identifier(self.db_user)
),
(self.db_password,)
)
cursor.close()
logger.info(f"✅ 用户创建成功")
return True
except Exception as e:
logger.error(f"❌ 创建用户失败: {e}")
return False
def create_database(self) -> bool:
"""创建数据库"""
try:
if self.database_exists():
logger.info(f"️ 数据库 '{self.db_name}' 已存在")
# 询问是否删除重建
response = input(f"是否删除并重建数据库 '{self.db_name}'? (yes/no): ")
if response.lower() in ['yes', 'y']:
logger.warning(f"⚠️ 删除数据库 '{self.db_name}'...")
cursor = self.conn.cursor()
# 断开所有连接
cursor.execute(
sql.SQL("""
SELECT pg_terminate_backend(pg_stat_activity.pid)
FROM pg_stat_activity
WHERE pg_stat_activity.datname = %s
AND pid <> pg_backend_pid()
"""),
(self.db_name,)
)
cursor.execute(
sql.SQL("DROP DATABASE {}").format(
sql.Identifier(self.db_name)
)
)
cursor.close()
logger.info(f"✅ 数据库已删除")
else:
return True
logger.info(f"🗄️ 创建数据库 '{self.db_name}'...")
cursor = self.conn.cursor()
cursor.execute(
sql.SQL("CREATE DATABASE {} OWNER {}").format(
sql.Identifier(self.db_name),
sql.Identifier(self.db_user)
)
)
cursor.close()
logger.info(f"✅ 数据库创建成功")
return True
except Exception as e:
logger.error(f"❌ 创建数据库失败: {e}")
return False
def grant_privileges(self) -> bool:
"""授予用户权限"""
try:
logger.info(f"🔐 授予用户权限...")
cursor = self.conn.cursor()
# 授予数据库所有权限
cursor.execute(
sql.SQL("GRANT ALL PRIVILEGES ON DATABASE {} TO {}").format(
sql.Identifier(self.db_name),
sql.Identifier(self.db_user)
)
)
cursor.close()
logger.info(f"✅ 权限授予成功")
return True
except Exception as e:
logger.error(f"❌ 授予权限失败: {e}")
return False
def update_env_file(self) -> bool:
"""更新.env文件"""
try:
env_file = Path(__file__).parent.parent / ".env"
database_url = (
f"postgresql+asyncpg://{self.db_user}:{self.db_password}"
f"@{self.host}:{self.port}/{self.db_name}"
)
if env_file.exists():
# 读取现有内容
with open(env_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
# 更新DATABASE_URL
updated = False
for i, line in enumerate(lines):
if line.startswith('DATABASE_URL='):
lines[i] = f"DATABASE_URL={database_url}\n"
updated = True
break
if not updated:
lines.append(f"\nDATABASE_URL={database_url}\n")
# 写回文件
with open(env_file, 'w', encoding='utf-8') as f:
f.writelines(lines)
else:
# 创建新文件
with open(env_file, 'w', encoding='utf-8') as f:
f.write(f"DATABASE_URL={database_url}\n")
logger.info(f"✅ .env 文件已更新")
logger.info(f" DATABASE_URL={database_url}")
return True
except Exception as e:
logger.error(f"❌ 更新.env文件失败: {e}")
return False
async def initialize_tables(self) -> bool:
"""初始化数据库表结构"""
try:
logger.info(f"📋 初始化数据库表结构...")
await init_db('system')
logger.info(f"✅ 表结构初始化成功")
return True
except Exception as e:
logger.error(f"❌ 初始化表结构失败: {e}")
return False
def close(self):
"""关闭数据库连接"""
if self.conn:
self.conn.close()
logger.info(f"🔌 已断开连接")
async def setup(self) -> bool:
"""执行完整设置流程"""
try:
# 1. 连接
if not self.connect_as_admin():
return False
# 2. 创建用户
if not self.create_user():
return False
# 3. 创建数据库
if not self.create_database():
return False
# 4. 授予权限
if not self.grant_privileges():
return False
# 5. 更新配置
if not self.update_env_file():
return False
# 6. 关闭管理员连接
self.close()
# 7. 初始化表结构
if not await self.initialize_tables():
return False
return True
except Exception as e:
logger.error(f"❌ 设置过程出错: {e}")
return False
finally:
if self.conn:
self.close()
async def main():
"""主函数"""
print("""
PostgreSQL 数据库自动设置工具
此工具将自动完成:
1. 连接到PostgreSQL服务器
2. 创建数据库和用户
3. 设置权限
4. 初始化表结构
5. 更新.env配置文件
""")
# 获取配置
print("请输入PostgreSQL配置信息:\n")
host = input("主机地址 [localhost]: ").strip() or "localhost"
port = input("端口 [5432]: ").strip() or "5432"
port = int(port)
admin_user = input("管理员用户名 [postgres]: ").strip() or "postgres"
admin_password = getpass(f"管理员密码: ")
print("\n请输入要创建的数据库信息:\n")
db_name = input("数据库名 [mumuai_novel]: ").strip() or "mumuai_novel"
db_user = input("数据库用户名 [mumuai]: ").strip() or "mumuai"
db_password = getpass("数据库用户密码 [mumuai123]: ") or "mumuai123"
print(f"\n{'='*60}")
print(f"配置摘要:")
print(f" 服务器: {host}:{port}")
print(f" 数据库: {db_name}")
print(f" 用户: {db_user}")
print(f"{'='*60}\n")
response = input("确认开始设置? (yes/no): ")
if response.lower() not in ['yes', 'y']:
print("已取消设置")
return
# 执行设置
setup = PostgreSQLSetup(
host=host,
port=port,
admin_user=admin_user,
admin_password=admin_password,
db_name=db_name,
db_user=db_user,
db_password=db_password
)
print(f"\n{'='*60}")
success = await setup.setup()
print(f"{'='*60}\n")
if success:
print("🎉 PostgreSQL设置完成!\n")
print("下一步:")
print("1. 启动应用: python -m app.main")
print("2. 访问: http://localhost:8000")
print("3. 查看API文档: http://localhost:8000/docs")
else:
print("❌ 设置过程中出现错误,请检查日志")
print("\n故障排查:")
print("1. 确认PostgreSQL服务正在运行")
print("2. 检查管理员用户名和密码")
print("3. 查看PostgreSQL日志")
if __name__ == "__main__":
asyncio.run(main())
+74 -5
View File
@@ -1,15 +1,67 @@
services: services:
postgres:
image: postgres:18-alpine
container_name: mumuainovel-postgres
environment:
POSTGRES_DB: mumuai_novel
POSTGRES_USER: mumuai
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-mumuai_password_2024}
POSTGRES_INITDB_ARGS: "--encoding=UTF8 --locale=C"
TZ: Asia/Shanghai
volumes:
- postgres_data:/var/lib/postgresql/data
- ./backend/scripts/init_postgres.sql:/docker-entrypoint-initdb.d/init.sql:ro
ports:
- "5432:5432"
restart: unless-stopped
healthcheck:
test: ["CMD-SHELL", "pg_isready -U mumuai -d mumuai_novel"]
interval: 10s
timeout: 5s
retries: 5
start_period: 10s
networks:
- ai-story-network
command:
- postgres
- -c
- max_connections=200
- -c
- shared_buffers=256MB
- -c
- effective_cache_size=1GB
- -c
- maintenance_work_mem=64MB
- -c
- checkpoint_completion_target=0.9
- -c
- wal_buffers=16MB
- -c
- default_statistics_target=100
- -c
- random_page_cost=1.1
- -c
- effective_io_concurrency=200
- -c
- work_mem=4MB
- -c
- min_wal_size=1GB
- -c
- max_wal_size=4GB
mumuainovel: mumuainovel:
build: build:
context: . context: .
dockerfile: Dockerfile dockerfile: Dockerfile
image: mumujie/mumuainovel:latest image: mumujie/mumuainovel:latest
container_name: mumuainovel container_name: mumuainovel
depends_on:
postgres:
condition: service_healthy
ports: ports:
- "8000:8000" - "8000:8000"
volumes: volumes:
# 持久化数据库和日志 # 持久化日志
- ./data:/app/data
- ./logs:/app/logs - ./logs:/app/logs
# 挂载环境变量文件(可选) # 挂载环境变量文件(可选)
- ./.env:/app/.env:ro - ./.env:/app/.env:ro
@@ -21,8 +73,20 @@ services:
- APP_PORT=8000 - APP_PORT=8000
- DEBUG=false - DEBUG=false
# 重要:环境变量会从 .env 文件自动加载 # 数据库配置(使用PostgreSQL
# 也可以在这里显式设置,优先级:此处设置 > .env 文件 - DATABASE_URL=postgresql+asyncpg://mumuai:${POSTGRES_PASSWORD:-mumuai_password_2024}@postgres:5432/mumuai_novel
# PostgreSQL连接池配置
- DATABASE_POOL_SIZE=30
- DATABASE_MAX_OVERFLOW=20
- DATABASE_POOL_TIMEOUT=60
- DATABASE_POOL_RECYCLE=1800
- DATABASE_POOL_PRE_PING=True
- DATABASE_POOL_USE_LIFO=True
- HTTP_PROXY=http://172.16.66.175:7890
- HTTPS_PROXY=http://172.16.66.175:7890
- NO_PROXY=localhost,127.0.0.1
# AI服务配置(建议在 .env 文件中设置) # AI服务配置(建议在 .env 文件中设置)
# - OPENAI_API_KEY=${OPENAI_API_KEY} # - OPENAI_API_KEY=${OPENAI_API_KEY}
@@ -41,10 +105,15 @@ services:
interval: 30s interval: 30s
timeout: 10s timeout: 10s
retries: 3 retries: 3
start_period: 10s start_period: 30s
networks: networks:
- ai-story-network - ai-story-network
volumes:
postgres_data:
driver: local
networks: networks:
ai-story-network: ai-story-network:
driver: bridge driver: bridge
@@ -67,6 +67,14 @@ export default function ChapterAnalysis({ chapterId, visible, onClose }: Chapter
} }
const taskData: AnalysisTask = await response.json(); const taskData: AnalysisTask = await response.json();
// 如果状态为 none(无任务),设置 task 为 null,让前端显示"开始分析"按钮
if (taskData.status === 'none' || !taskData.has_task) {
setTask(null);
setError(null); // 清除错误,这不是错误状态
return;
}
setTask(taskData); setTask(taskData);
if (taskData.status === 'completed') { if (taskData.status === 'completed') {
+1
View File
@@ -321,6 +321,7 @@ export default function Chapters() {
setAnalysisTasksMap(prev => ({ setAnalysisTasksMap(prev => ({
...prev, ...prev,
[editingId]: { [editingId]: {
has_task: true,
task_id: taskId, task_id: taskId,
chapter_id: editingId, chapter_id: editingId,
status: 'pending', status: 'pending',
+8 -6
View File
@@ -390,14 +390,16 @@ export interface ApiError {
// 章节分析任务相关类型 // 章节分析任务相关类型
export interface AnalysisTask { export interface AnalysisTask {
task_id: string; has_task: boolean;
task_id: string | null;
chapter_id: string; chapter_id: string;
status: 'pending' | 'running' | 'completed' | 'failed'; status: 'pending' | 'running' | 'completed' | 'failed' | 'none';
progress: number; progress: number;
error_message?: string; error_message?: string | null;
created_at?: string; auto_recovered?: boolean;
started_at?: string; created_at?: string | null;
completed_at?: string; started_at?: string | null;
completed_at?: string | null;
} }
// 分析结果 - 钩子 // 分析结果 - 钩子