update:1.切换数据库PostgreSQL
This commit is contained in:
@@ -39,6 +39,7 @@ Thumbs.db
|
|||||||
# 数据库文件(不包含在镜像中)
|
# 数据库文件(不包含在镜像中)
|
||||||
data/*.db
|
data/*.db
|
||||||
backend/data/*.db
|
backend/data/*.db
|
||||||
|
postgres_data/
|
||||||
|
|
||||||
# ChromaDB数据库(不包含在镜像中,会在运行时生成)
|
# ChromaDB数据库(不包含在镜像中,会在运行时生成)
|
||||||
backend/data/chroma_db/
|
backend/data/chroma_db/
|
||||||
|
|||||||
@@ -26,7 +26,7 @@
|
|||||||
- 🌐 **世界观设定** - 构建完整的故事世界观和背景设定
|
- 🌐 **世界观设定** - 构建完整的故事世界观和背景设定
|
||||||
- 🔐 **多种登录方式** - 支持 LinuxDO OAuth 登录和本地账户登录
|
- 🔐 **多种登录方式** - 支持 LinuxDO OAuth 登录和本地账户登录
|
||||||
- 🐳 **Docker 部署** - 一键部署,开箱即用
|
- 🐳 **Docker 部署** - 一键部署,开箱即用
|
||||||
- 💾 **数据持久化** - 基于 SQLite 的本地数据存储,支持多用户隔离
|
- 💾 **数据持久化** - 支持 PostgreSQL 和 SQLite 双数据库,多用户数据隔离
|
||||||
- 🎨 **现代化 UI** - 基于 Ant Design 的美观界面,响应式设计
|
- 🎨 **现代化 UI** - 基于 Ant Design 的美观界面,响应式设计
|
||||||
|
|
||||||
|
|
||||||
@@ -36,7 +36,7 @@
|
|||||||
|
|
||||||
- [ ] **灵感模式** - 提供创作灵感和点子生成功能
|
- [ ] **灵感模式** - 提供创作灵感和点子生成功能
|
||||||
- [✔] **自定义写作风格** - 支持自定义AI写作风格和语言风格
|
- [✔] **自定义写作风格** - 支持自定义AI写作风格和语言风格
|
||||||
- [ ] **支持数据导入导出** - 支持项目数据的导入和导出功能
|
- [✔] **支持数据导入导出** - 支持项目数据的导入和导出功能
|
||||||
- [ ] **添加prompt调整界面** - 提供可视化的prompt模板编辑和调整界面
|
- [ ] **添加prompt调整界面** - 提供可视化的prompt模板编辑和调整界面
|
||||||
- [✔] **开放章节内容字数限制** - 支持用户在生成章节内容时设置字数 @wyf007
|
- [✔] **开放章节内容字数限制** - 支持用户在生成章节内容时设置字数 @wyf007
|
||||||
- [ ] **设定追溯与矛盾检测** - 对大纲、世界观、角色档案中的设定支持悬停查看注释,显示相关章节来源和佐证原文;自动检测新章节与已有设定的矛盾(吃书),标记为"矛盾"设定并提供解决建议,当新设定解决矛盾后自动更新注释说明 @lulujiang
|
- [ ] **设定追溯与矛盾检测** - 对大纲、世界观、角色档案中的设定支持悬停查看注释,显示相关章节来源和佐证原文;自动检测新章节与已有设定的矛盾(吃书),标记为"矛盾"设定并提供解决建议,当新设定解决矛盾后自动更新注释说明 @lulujiang
|
||||||
@@ -116,9 +116,77 @@ npm run build
|
|||||||
|
|
||||||
## 🐳 部署方式
|
## 🐳 部署方式
|
||||||
|
|
||||||
### Docker Compose 部署
|
### Docker Compose 部署(PostgreSQL)
|
||||||
|
|
||||||
#### 使用 Docker Hub 镜像(推荐)
|
**推荐生产环境使用**,提供更好的性能和并发支持。
|
||||||
|
|
||||||
|
#### 快速启动
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 1. 克隆项目
|
||||||
|
git clone https://github.com/xiamuceer-j/MuMuAINovel.git
|
||||||
|
cd MuMuAINovel
|
||||||
|
|
||||||
|
# 2. 配置环境变量
|
||||||
|
cp backend/.env.example .env
|
||||||
|
# 编辑 .env 文件,设置必要的配置:
|
||||||
|
# - POSTGRES_PASSWORD(数据库密码)
|
||||||
|
# - OPENAI_API_KEY(AI服务密钥)
|
||||||
|
# - 其他可选配置
|
||||||
|
|
||||||
|
# 3. 启动服务(包含PostgreSQL)
|
||||||
|
docker-compose up -d
|
||||||
|
|
||||||
|
# 4. 查看服务状态
|
||||||
|
docker-compose ps
|
||||||
|
|
||||||
|
# 5. 查看日志
|
||||||
|
docker-compose logs -f
|
||||||
|
|
||||||
|
# 6. 访问应用
|
||||||
|
# 打开浏览器访问 http://localhost:8000
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 环境变量配置
|
||||||
|
|
||||||
|
创建 `.env` 文件并配置:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# PostgreSQL数据库密码(必须设置)
|
||||||
|
POSTGRES_PASSWORD=your_secure_password_here
|
||||||
|
|
||||||
|
# AI服务配置(必须设置)
|
||||||
|
OPENAI_API_KEY=your_openai_api_key
|
||||||
|
DEFAULT_AI_PROVIDER=openai
|
||||||
|
DEFAULT_MODEL=gpt-4
|
||||||
|
|
||||||
|
# 本地账户登录(可选)
|
||||||
|
LOCAL_AUTH_ENABLED=true
|
||||||
|
LOCAL_AUTH_USERNAME=admin
|
||||||
|
LOCAL_AUTH_PASSWORD=admin123
|
||||||
|
|
||||||
|
# 其他配置见 backend/.env.example
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 服务说明
|
||||||
|
|
||||||
|
- **postgres**: PostgreSQL 18 数据库
|
||||||
|
- 端口:5432
|
||||||
|
- 数据持久化:`./postgres_data`
|
||||||
|
- 已优化配置,支持80-150并发用户
|
||||||
|
|
||||||
|
- **mumuainovel**: 主应用服务
|
||||||
|
- 端口:8000
|
||||||
|
- 自动等待数据库就绪
|
||||||
|
- 日志持久化:`./logs`
|
||||||
|
|
||||||
|
详细部署指南请参考:[Docker + PostgreSQL 部署文档](docs/docker-postgres-deployment.md)
|
||||||
|
|
||||||
|
### Docker Compose 部署(SQLite)
|
||||||
|
|
||||||
|
适合个人使用或小团队,配置更简单。
|
||||||
|
|
||||||
|
#### 使用 Docker Hub 镜像
|
||||||
|
|
||||||
项目已发布到 Docker Hub,可直接拉取使用:
|
项目已发布到 Docker Hub,可直接拉取使用:
|
||||||
|
|
||||||
@@ -197,7 +265,11 @@ networks:
|
|||||||
|
|
||||||
#### 2. 数据持久化
|
#### 2. 数据持久化
|
||||||
|
|
||||||
数据目录已通过 volume 挂载,数据不会丢失:
|
**PostgreSQL部署**:
|
||||||
|
- `./postgres_data`:PostgreSQL 数据库文件
|
||||||
|
- `./logs`:应用日志文件
|
||||||
|
|
||||||
|
**SQLite部署**:
|
||||||
- `./data`:SQLite 数据库文件
|
- `./data`:SQLite 数据库文件
|
||||||
- `./logs`:应用日志文件
|
- `./logs`:应用日志文件
|
||||||
|
|
||||||
@@ -238,11 +310,33 @@ services:
|
|||||||
|
|
||||||
## ⚙️ 配置说明
|
## ⚙️ 配置说明
|
||||||
|
|
||||||
|
### 数据库选择
|
||||||
|
|
||||||
|
项目支持两种数据库:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# .env 配置
|
||||||
|
DATABASE_URL=postgresql+asyncpg://mumuai:password@postgres:5432/mumuai_novel
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# .env 配置
|
||||||
|
DATABASE_URL=sqlite+aiosqlite:///data/ai_story.db
|
||||||
|
```
|
||||||
|
|
||||||
### 环境变量
|
### 环境变量
|
||||||
|
|
||||||
创建 `.env` 文件并配置以下变量:
|
创建 `.env` 文件并配置以下变量:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
# ===== 数据库配置 =====
|
||||||
|
# PostgreSQL(生产环境推荐)
|
||||||
|
DATABASE_URL=postgresql+asyncpg://mumuai:password@postgres:5432/mumuai_novel
|
||||||
|
POSTGRES_PASSWORD=your_secure_password_here
|
||||||
|
|
||||||
|
# 或使用 SQLite(开发环境)
|
||||||
|
# DATABASE_URL=sqlite+aiosqlite:///data/ai_story.db
|
||||||
|
|
||||||
# ===== AI 服务配置(必填)=====
|
# ===== AI 服务配置(必填)=====
|
||||||
# OpenAI 配置(支持官方API和中转API)
|
# OpenAI 配置(支持官方API和中转API)
|
||||||
OPENAI_API_KEY=your_openai_key_here
|
OPENAI_API_KEY=your_openai_key_here
|
||||||
@@ -325,24 +419,6 @@ OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxx
|
|||||||
OPENAI_BASE_URL=https://api.new-api.com/v1
|
OPENAI_BASE_URL=https://api.new-api.com/v1
|
||||||
```
|
```
|
||||||
|
|
||||||
**API2D**
|
|
||||||
```bash
|
|
||||||
OPENAI_API_KEY=fk-xxxxxxxxxxxxxxxx
|
|
||||||
OPENAI_BASE_URL=https://api.api2d.com/v1
|
|
||||||
```
|
|
||||||
|
|
||||||
**OpenAI-SB**
|
|
||||||
```bash
|
|
||||||
OPENAI_API_KEY=sb-xxxxxxxxxxxxxxxx
|
|
||||||
OPENAI_BASE_URL=https://api.openai-sb.com/v1
|
|
||||||
```
|
|
||||||
|
|
||||||
**自建 One API / New API**
|
|
||||||
```bash
|
|
||||||
OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxx
|
|
||||||
OPENAI_BASE_URL=https://your-domain.com/v1
|
|
||||||
```
|
|
||||||
|
|
||||||
##### 注意事项
|
##### 注意事项
|
||||||
|
|
||||||
- ✅ 所有支持 OpenAI 接口格式的服务都可以使用
|
- ✅ 所有支持 OpenAI 接口格式的服务都可以使用
|
||||||
@@ -436,7 +512,7 @@ MuMuAINovel/
|
|||||||
### 后端
|
### 后端
|
||||||
|
|
||||||
- **框架**:FastAPI 0.109.0
|
- **框架**:FastAPI 0.109.0
|
||||||
- **数据库**:SQLite + SQLAlchemy(异步)
|
- **数据库**:PostgreSQL / SQLite + SQLAlchemy(异步)
|
||||||
- **AI 集成**:OpenAI、Anthropic、Google Gemini SDK
|
- **AI 集成**:OpenAI、Anthropic、Google Gemini SDK
|
||||||
- **认证**:LinuxDO OAuth2、本地账户
|
- **认证**:LinuxDO OAuth2、本地账户
|
||||||
- **日志**:Python logging + 文件轮转
|
- **日志**:Python logging + 文件轮转
|
||||||
|
|||||||
+124
-38
@@ -1,54 +1,140 @@
|
|||||||
# AI服务配置
|
# ==========================================
|
||||||
# OpenAI配置
|
# MuMuAINovel 配置文件示例
|
||||||
OPENAI_API_KEY=your_openai_key_here
|
# ==========================================
|
||||||
OPENAI_BASE_URL=https://api.openai.com/v1
|
# 复制此文件为 .env 并修改配置值
|
||||||
|
# cp .env.example .env
|
||||||
# Anthropic配置
|
|
||||||
ANTHROPIC_API_KEY=your_anthropic_key_here
|
|
||||||
ANTHROPIC_BASE_URL=https://api.anthropic.com
|
|
||||||
|
|
||||||
# 默认AI提供商:openai, gemini, anthropic
|
|
||||||
DEFAULT_AI_PROVIDER=openai
|
|
||||||
DEFAULT_MODEL=gpt-4.1
|
|
||||||
DEFAULT_TEMPERATURE=0.8
|
|
||||||
DEFAULT_MAX_TOKENS=32000
|
|
||||||
|
|
||||||
|
# ==========================================
|
||||||
# 应用配置
|
# 应用配置
|
||||||
|
# ==========================================
|
||||||
APP_NAME=MuMuAINovel
|
APP_NAME=MuMuAINovel
|
||||||
APP_VERSION=1.0.0
|
APP_VERSION=1.0.0
|
||||||
APP_HOST=0.0.0.0
|
APP_HOST=0.0.0.0
|
||||||
APP_PORT=8000
|
APP_PORT=8000
|
||||||
DEBUG=true
|
DEBUG=True
|
||||||
|
|
||||||
|
# ==========================================
|
||||||
|
# 数据库配置
|
||||||
|
# ==========================================
|
||||||
|
|
||||||
|
# 选项1: PostgreSQL(生产环境推荐)
|
||||||
|
# DATABASE_URL=postgresql+asyncpg://username:password@localhost:5432/database_name
|
||||||
|
# 示例:
|
||||||
|
# DATABASE_URL=postgresql+asyncpg://mumuai:your_password@localhost:5432/mumuai_novel
|
||||||
|
|
||||||
|
# 选项2: SQLite(开发环境,默认)
|
||||||
|
DATABASE_URL=sqlite+aiosqlite:///data/ai_story.db
|
||||||
|
|
||||||
|
# PostgreSQL 连接池配置(优化后,支持80-150并发用户)
|
||||||
|
DATABASE_POOL_SIZE=30 # 核心连接数(默认30,小团队可用20)
|
||||||
|
DATABASE_MAX_OVERFLOW=20 # 最大溢出连接数(默认20,小团队可用10)
|
||||||
|
DATABASE_POOL_TIMEOUT=60 # 连接等待超时秒数(默认60)
|
||||||
|
DATABASE_POOL_RECYCLE=1800 # 连接回收时间秒数(默认1800=30分钟)
|
||||||
|
DATABASE_POOL_PRE_PING=True # 连接前检测是否有效
|
||||||
|
DATABASE_POOL_USE_LIFO=True # 使用LIFO策略提高连接复用率
|
||||||
|
|
||||||
|
# 会话监控配置
|
||||||
|
DATABASE_SESSION_MAX_ACTIVE=50 # 活跃会话警告阈值
|
||||||
|
DATABASE_SESSION_LEAK_THRESHOLD=100 # 会话泄漏严重告警阈值
|
||||||
|
|
||||||
|
# SQLite 性能优化配置(仅在使用SQLite时生效)
|
||||||
|
SQLITE_CACHE_SIZE_MB=128 # SQLite缓存大小(MB),默认128
|
||||||
|
SQLITE_MMAP_SIZE_MB=256 # 内存映射I/O大小(MB),默认256
|
||||||
|
SQLITE_WAL_AUTOCHECKPOINT=1000 # WAL自动检查点间隔
|
||||||
|
|
||||||
|
# 数据库监控配置
|
||||||
|
DATABASE_ENABLE_SLOW_QUERY_LOG=True # 启用慢查询日志
|
||||||
|
DATABASE_SLOW_QUERY_THRESHOLD=1.0 # 慢查询阈值(秒)
|
||||||
|
DATABASE_ENABLE_METRICS=True # 启用性能指标收集
|
||||||
|
|
||||||
|
# ==========================================
|
||||||
|
# 日志配置
|
||||||
|
# ==========================================
|
||||||
|
LOG_LEVEL=INFO
|
||||||
|
LOG_TO_FILE=True
|
||||||
|
LOG_FILE_PATH=logs/app.log
|
||||||
|
LOG_MAX_BYTES=10485760
|
||||||
|
LOG_BACKUP_COUNT=30
|
||||||
|
|
||||||
|
# ==========================================
|
||||||
|
# CORS配置
|
||||||
|
# ==========================================
|
||||||
|
CORS_ORIGINS=["http://localhost:8000","http://127.0.0.1:8000"]
|
||||||
|
|
||||||
|
# ==========================================
|
||||||
|
# AI服务配置
|
||||||
|
# ==========================================
|
||||||
|
|
||||||
|
# OpenAI配置
|
||||||
|
OPENAI_API_KEY=your_openai_api_key_here
|
||||||
|
OPENAI_BASE_URL=https://api.openai.com/v1
|
||||||
|
# 或使用兼容的API服务(如DeepSeek)
|
||||||
|
# OPENAI_BASE_URL=https://api.deepseek.com/v1
|
||||||
|
|
||||||
|
# Gemini配置(可选)
|
||||||
|
# GEMINI_API_KEY=your_gemini_api_key_here
|
||||||
|
# GEMINI_BASE_URL=https://generativelanguage.googleapis.com
|
||||||
|
|
||||||
|
# Anthropic配置(可选)
|
||||||
|
# ANTHROPIC_API_KEY=your_anthropic_api_key_here
|
||||||
|
# ANTHROPIC_BASE_URL=https://api.anthropic.com
|
||||||
|
|
||||||
|
# 默认AI配置
|
||||||
|
DEFAULT_AI_PROVIDER=openai
|
||||||
|
DEFAULT_MODEL=gpt-4
|
||||||
|
DEFAULT_TEMPERATURE=0.7
|
||||||
|
DEFAULT_MAX_TOKENS=2000
|
||||||
|
|
||||||
|
# ==========================================
|
||||||
# LinuxDO OAuth2 配置(可选)
|
# LinuxDO OAuth2 配置(可选)
|
||||||
# 注意:Docker部署时,LINUXDO_REDIRECT_URI 应该使用实际的域名或服务器IP
|
# ==========================================
|
||||||
# 本地开发: http://localhost:8000/api/auth/callback
|
# LINUXDO_CLIENT_ID=your_client_id
|
||||||
# 生产环境: https://your-domain.com/api/auth/callback 或 http://your-server-ip:8000/api/auth/callback
|
# LINUXDO_CLIENT_SECRET=your_client_secret
|
||||||
LINUXDO_CLIENT_ID=your_client_id_here
|
# LINUXDO_REDIRECT_URI=http://localhost:8000/api/auth/callback
|
||||||
LINUXDO_CLIENT_SECRET=your_client_secret_here
|
|
||||||
LINUXDO_REDIRECT_URI=http://localhost:8000/api/auth/callback
|
|
||||||
|
|
||||||
# 前端URL配置(用于OAuth回调后重定向到前端)
|
# 前端URL(OAuth回调后重定向)
|
||||||
# 本地开发: http://localhost:8000
|
|
||||||
# 生产环境: https://your-domain.com 或 http://your-server-ip:8000
|
|
||||||
FRONTEND_URL=http://localhost:8000
|
FRONTEND_URL=http://localhost:8000
|
||||||
|
|
||||||
# 本地账户登录配置
|
# 初始管理员(LinuxDO user_id)
|
||||||
# 启用本地账户登录(true/false)
|
# INITIAL_ADMIN_LINUXDO_ID=12345
|
||||||
LOCAL_AUTH_ENABLED=true
|
|
||||||
# 本地登录用户名
|
|
||||||
LOCAL_AUTH_USERNAME=admin
|
|
||||||
# 本地登录密码
|
|
||||||
LOCAL_AUTH_PASSWORD=your_secure_password_here
|
|
||||||
# 本地用户显示名称
|
|
||||||
LOCAL_AUTH_DISPLAY_NAME=管理员
|
|
||||||
|
|
||||||
|
# ==========================================
|
||||||
|
# 本地账户登录配置
|
||||||
|
# ==========================================
|
||||||
|
LOCAL_AUTH_ENABLED=True
|
||||||
|
LOCAL_AUTH_USERNAME=admin
|
||||||
|
LOCAL_AUTH_PASSWORD=admin123
|
||||||
|
LOCAL_AUTH_DISPLAY_NAME=本地管理员
|
||||||
|
|
||||||
|
# ==========================================
|
||||||
# 会话配置
|
# 会话配置
|
||||||
# 会话过期时间(分钟),默认120分钟(2小时)
|
# ==========================================
|
||||||
SESSION_EXPIRE_MINUTES=120
|
SESSION_EXPIRE_MINUTES=120
|
||||||
# 会话刷新阈值(分钟),剩余时间少于此值时可刷新,默认30分钟
|
|
||||||
SESSION_REFRESH_THRESHOLD_MINUTES=30
|
SESSION_REFRESH_THRESHOLD_MINUTES=30
|
||||||
|
|
||||||
# CORS配置(生产环境)
|
# ==========================================
|
||||||
# 允许的跨域来源,多个用逗号分隔
|
# 部署配置说明
|
||||||
# CORS_ORIGINS=https://your-domain.com,https://www.your-domain.com
|
# ==========================================
|
||||||
|
|
||||||
|
# 生产环境 PostgreSQL 配置示例(50-100并发用户):
|
||||||
|
# DATABASE_URL=postgresql+asyncpg://mumuai:SecurePassword123@db.example.com:5432/mumuai_prod
|
||||||
|
# DATABASE_POOL_SIZE=30
|
||||||
|
# DATABASE_MAX_OVERFLOW=20
|
||||||
|
# DATABASE_POOL_TIMEOUT=60
|
||||||
|
# DATABASE_POOL_RECYCLE=1800
|
||||||
|
# DATABASE_SESSION_MAX_ACTIVE=50
|
||||||
|
# DEBUG=False
|
||||||
|
# LOG_LEVEL=WARNING
|
||||||
|
|
||||||
|
# 高并发环境 PostgreSQL 配置示例(100+并发用户):
|
||||||
|
# DATABASE_URL=postgresql+asyncpg://mumuai:SecurePassword123@db.example.com:5432/mumuai_prod
|
||||||
|
# DATABASE_POOL_SIZE=40
|
||||||
|
# DATABASE_MAX_OVERFLOW=30
|
||||||
|
# DATABASE_SESSION_MAX_ACTIVE=80
|
||||||
|
# DATABASE_SESSION_LEAK_THRESHOLD=150
|
||||||
|
|
||||||
|
# Docker 部署配置示例:
|
||||||
|
# DATABASE_URL=postgresql+asyncpg://mumuai:password@postgres:5432/mumuai_novel
|
||||||
|
# OPENAI_BASE_URL=https://api.openai.com/v1
|
||||||
|
# FRONTEND_URL=https://your-domain.com
|
||||||
|
# LINUXDO_REDIRECT_URI=https://your-domain.com/api/auth/callback
|
||||||
+125
-17
@@ -43,6 +43,39 @@ logger = get_logger(__name__)
|
|||||||
db_write_locks: dict[str, Lock] = {}
|
db_write_locks: dict[str, Lock] = {}
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
|
||||||
|
"""
|
||||||
|
验证用户是否有权访问指定项目
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: 项目ID
|
||||||
|
user_id: 用户ID
|
||||||
|
db: 数据库会话
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Project: 项目对象
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: 401 未登录,404 项目不存在或无权访问
|
||||||
|
"""
|
||||||
|
if not user_id:
|
||||||
|
raise HTTPException(status_code=401, detail="未登录")
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Project).where(
|
||||||
|
Project.id == project_id,
|
||||||
|
Project.user_id == user_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
project = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not project:
|
||||||
|
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
|
||||||
|
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||||
|
|
||||||
|
return project
|
||||||
|
|
||||||
|
|
||||||
async def get_db_write_lock(user_id: str) -> Lock:
|
async def get_db_write_lock(user_id: str) -> Lock:
|
||||||
"""获取或创建用户的数据库写入锁"""
|
"""获取或创建用户的数据库写入锁"""
|
||||||
if user_id not in db_write_locks:
|
if user_id not in db_write_locks:
|
||||||
@@ -54,16 +87,13 @@ async def get_db_write_lock(user_id: str) -> Lock:
|
|||||||
@router.post("", response_model=ChapterResponse, summary="创建章节")
|
@router.post("", response_model=ChapterResponse, summary="创建章节")
|
||||||
async def create_chapter(
|
async def create_chapter(
|
||||||
chapter: ChapterCreate,
|
chapter: ChapterCreate,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""创建新的章节"""
|
"""创建新的章节"""
|
||||||
# 验证项目是否存在
|
# 验证用户权限和项目是否存在
|
||||||
result = await db.execute(
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
select(Project).where(Project.id == chapter.project_id)
|
project = await verify_project_access(chapter.project_id, user_id, db)
|
||||||
)
|
|
||||||
project = result.scalar_one_or_none()
|
|
||||||
if not project:
|
|
||||||
raise HTTPException(status_code=404, detail="项目不存在")
|
|
||||||
|
|
||||||
# 计算字数
|
# 计算字数
|
||||||
word_count = len(chapter.content)
|
word_count = len(chapter.content)
|
||||||
@@ -85,9 +115,14 @@ async def create_chapter(
|
|||||||
@router.get("/project/{project_id}", response_model=ChapterListResponse, summary="获取项目的所有章节")
|
@router.get("/project/{project_id}", response_model=ChapterListResponse, summary="获取项目的所有章节")
|
||||||
async def get_project_chapters(
|
async def get_project_chapters(
|
||||||
project_id: str,
|
project_id: str,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""获取指定项目的所有章节(路径参数版本)"""
|
"""获取指定项目的所有章节(路径参数版本)"""
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(project_id, user_id, db)
|
||||||
|
|
||||||
# 获取总数
|
# 获取总数
|
||||||
count_result = await db.execute(
|
count_result = await db.execute(
|
||||||
select(func.count(Chapter.id)).where(Chapter.project_id == project_id)
|
select(func.count(Chapter.id)).where(Chapter.project_id == project_id)
|
||||||
@@ -108,6 +143,7 @@ async def get_project_chapters(
|
|||||||
@router.get("/{chapter_id}", response_model=ChapterResponse, summary="获取章节详情")
|
@router.get("/{chapter_id}", response_model=ChapterResponse, summary="获取章节详情")
|
||||||
async def get_chapter(
|
async def get_chapter(
|
||||||
chapter_id: str,
|
chapter_id: str,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""根据ID获取章节详情"""
|
"""根据ID获取章节详情"""
|
||||||
@@ -119,12 +155,17 @@ async def get_chapter(
|
|||||||
if not chapter:
|
if not chapter:
|
||||||
raise HTTPException(status_code=404, detail="章节不存在")
|
raise HTTPException(status_code=404, detail="章节不存在")
|
||||||
|
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(chapter.project_id, user_id, db)
|
||||||
|
|
||||||
return chapter
|
return chapter
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{chapter_id}/navigation", summary="获取章节导航信息")
|
@router.get("/{chapter_id}/navigation", summary="获取章节导航信息")
|
||||||
async def get_chapter_navigation(
|
async def get_chapter_navigation(
|
||||||
chapter_id: str,
|
chapter_id: str,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -140,6 +181,10 @@ async def get_chapter_navigation(
|
|||||||
if not current_chapter:
|
if not current_chapter:
|
||||||
raise HTTPException(status_code=404, detail="章节不存在")
|
raise HTTPException(status_code=404, detail="章节不存在")
|
||||||
|
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(current_chapter.project_id, user_id, db)
|
||||||
|
|
||||||
# 获取上一章
|
# 获取上一章
|
||||||
prev_result = await db.execute(
|
prev_result = await db.execute(
|
||||||
select(Chapter)
|
select(Chapter)
|
||||||
@@ -183,6 +228,7 @@ async def get_chapter_navigation(
|
|||||||
async def update_chapter(
|
async def update_chapter(
|
||||||
chapter_id: str,
|
chapter_id: str,
|
||||||
chapter_update: ChapterUpdate,
|
chapter_update: ChapterUpdate,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""更新章节信息"""
|
"""更新章节信息"""
|
||||||
@@ -194,6 +240,10 @@ async def update_chapter(
|
|||||||
if not chapter:
|
if not chapter:
|
||||||
raise HTTPException(status_code=404, detail="章节不存在")
|
raise HTTPException(status_code=404, detail="章节不存在")
|
||||||
|
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(chapter.project_id, user_id, db)
|
||||||
|
|
||||||
# 记录旧字数
|
# 记录旧字数
|
||||||
old_word_count = chapter.word_count or 0
|
old_word_count = chapter.word_count or 0
|
||||||
|
|
||||||
@@ -223,6 +273,7 @@ async def update_chapter(
|
|||||||
@router.delete("/{chapter_id}", summary="删除章节")
|
@router.delete("/{chapter_id}", summary="删除章节")
|
||||||
async def delete_chapter(
|
async def delete_chapter(
|
||||||
chapter_id: str,
|
chapter_id: str,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""删除章节"""
|
"""删除章节"""
|
||||||
@@ -234,6 +285,10 @@ async def delete_chapter(
|
|||||||
if not chapter:
|
if not chapter:
|
||||||
raise HTTPException(status_code=404, detail="章节不存在")
|
raise HTTPException(status_code=404, detail="章节不存在")
|
||||||
|
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(chapter.project_id, user_id, db)
|
||||||
|
|
||||||
# 更新项目字数
|
# 更新项目字数
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Project).where(Project.id == chapter.project_id)
|
select(Project).where(Project.id == chapter.project_id)
|
||||||
@@ -481,6 +536,7 @@ async def build_smart_chapter_context(
|
|||||||
@router.get("/{chapter_id}/can-generate", summary="检查章节是否可以生成")
|
@router.get("/{chapter_id}/can-generate", summary="检查章节是否可以生成")
|
||||||
async def check_can_generate(
|
async def check_can_generate(
|
||||||
chapter_id: str,
|
chapter_id: str,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -495,6 +551,10 @@ async def check_can_generate(
|
|||||||
if not chapter:
|
if not chapter:
|
||||||
raise HTTPException(status_code=404, detail="章节不存在")
|
raise HTTPException(status_code=404, detail="章节不存在")
|
||||||
|
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(chapter.project_id, user_id, db)
|
||||||
|
|
||||||
# 检查前置条件
|
# 检查前置条件
|
||||||
can_generate, error_msg, previous_chapters = await check_prerequisites(db, chapter)
|
can_generate, error_msg, previous_chapters = await check_prerequisites(db, chapter)
|
||||||
|
|
||||||
@@ -1238,6 +1298,7 @@ async def generate_chapter_content_stream(
|
|||||||
@router.get("/{chapter_id}/analysis/status", summary="查询章节分析任务状态")
|
@router.get("/{chapter_id}/analysis/status", summary="查询章节分析任务状态")
|
||||||
async def get_analysis_task_status(
|
async def get_analysis_task_status(
|
||||||
chapter_id: str,
|
chapter_id: str,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -1248,16 +1309,32 @@ async def get_analysis_task_status(
|
|||||||
- 如果任务状态为pending且超过2分钟未启动,自动标记为failed
|
- 如果任务状态为pending且超过2分钟未启动,自动标记为failed
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
- task_id: 任务ID
|
- has_task: 是否存在分析任务
|
||||||
- status: pending/running/completed/failed
|
- task_id: 任务ID(如果存在)
|
||||||
|
- status: pending/running/completed/failed/none(如果不存在则为none)
|
||||||
- progress: 0-100
|
- progress: 0-100
|
||||||
- error_message: 错误信息(如果失败)
|
- error_message: 错误信息(如果失败)
|
||||||
- auto_recovered: 是否被自动恢复
|
- auto_recovered: 是否被自动恢复
|
||||||
- created_at: 创建时间
|
- created_at: 创建时间
|
||||||
- completed_at: 完成时间
|
- completed_at: 完成时间
|
||||||
|
|
||||||
|
注意:当章节不存在或无权访问时返回404,当没有分析任务时返回has_task=false
|
||||||
"""
|
"""
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
|
||||||
|
# 先获取章节以验证存在性和权限
|
||||||
|
chapter_result = await db.execute(
|
||||||
|
select(Chapter).where(Chapter.id == chapter_id)
|
||||||
|
)
|
||||||
|
chapter = chapter_result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not chapter:
|
||||||
|
raise HTTPException(status_code=404, detail="章节不存在")
|
||||||
|
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(chapter.project_id, user_id, db)
|
||||||
|
|
||||||
# 获取该章节最新的分析任务
|
# 获取该章节最新的分析任务
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(AnalysisTask)
|
select(AnalysisTask)
|
||||||
@@ -1268,7 +1345,19 @@ async def get_analysis_task_status(
|
|||||||
task = result.scalar_one_or_none()
|
task = result.scalar_one_or_none()
|
||||||
|
|
||||||
if not task:
|
if not task:
|
||||||
raise HTTPException(status_code=404, detail="未找到分析任务")
|
# 返回无任务状态,而不是抛出404错误
|
||||||
|
return {
|
||||||
|
"has_task": False,
|
||||||
|
"chapter_id": chapter_id,
|
||||||
|
"status": "none",
|
||||||
|
"progress": 0,
|
||||||
|
"error_message": None,
|
||||||
|
"auto_recovered": False,
|
||||||
|
"task_id": None,
|
||||||
|
"created_at": None,
|
||||||
|
"started_at": None,
|
||||||
|
"completed_at": None
|
||||||
|
}
|
||||||
|
|
||||||
auto_recovered = False
|
auto_recovered = False
|
||||||
current_time = datetime.now()
|
current_time = datetime.now()
|
||||||
@@ -1299,6 +1388,7 @@ async def get_analysis_task_status(
|
|||||||
logger.warning(f"🔄 自动恢复未启动的任务: {task.id}, 章节: {chapter_id}")
|
logger.warning(f"🔄 自动恢复未启动的任务: {task.id}, 章节: {chapter_id}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
"has_task": True,
|
||||||
"task_id": task.id,
|
"task_id": task.id,
|
||||||
"chapter_id": task.chapter_id,
|
"chapter_id": task.chapter_id,
|
||||||
"status": task.status,
|
"status": task.status,
|
||||||
@@ -1314,6 +1404,7 @@ async def get_analysis_task_status(
|
|||||||
@router.get("/{chapter_id}/analysis", summary="获取章节分析结果")
|
@router.get("/{chapter_id}/analysis", summary="获取章节分析结果")
|
||||||
async def get_chapter_analysis(
|
async def get_chapter_analysis(
|
||||||
chapter_id: str,
|
chapter_id: str,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -1325,6 +1416,16 @@ async def get_chapter_analysis(
|
|||||||
- memories: 提取的记忆列表
|
- memories: 提取的记忆列表
|
||||||
- created_at: 分析时间
|
- created_at: 分析时间
|
||||||
"""
|
"""
|
||||||
|
# 先获取章节以验证权限
|
||||||
|
chapter_result_check = await db.execute(
|
||||||
|
select(Chapter).where(Chapter.id == chapter_id)
|
||||||
|
)
|
||||||
|
chapter_check = chapter_result_check.scalar_one_or_none()
|
||||||
|
if chapter_check:
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(chapter_check.project_id, user_id, db)
|
||||||
|
|
||||||
# 获取分析结果
|
# 获取分析结果
|
||||||
analysis_result = await db.execute(
|
analysis_result = await db.execute(
|
||||||
select(PlotAnalysis)
|
select(PlotAnalysis)
|
||||||
@@ -1369,6 +1470,7 @@ async def get_chapter_analysis(
|
|||||||
@router.get("/{chapter_id}/annotations", summary="获取章节标注数据")
|
@router.get("/{chapter_id}/annotations", summary="获取章节标注数据")
|
||||||
async def get_chapter_annotations(
|
async def get_chapter_annotations(
|
||||||
chapter_id: str,
|
chapter_id: str,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -1377,6 +1479,9 @@ async def get_chapter_annotations(
|
|||||||
返回格式化的标注列表,包含精确位置信息
|
返回格式化的标注列表,包含精确位置信息
|
||||||
适用于章节内容的可视化标注展示
|
适用于章节内容的可视化标注展示
|
||||||
"""
|
"""
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
|
||||||
# 获取章节
|
# 获取章节
|
||||||
chapter_result = await db.execute(
|
chapter_result = await db.execute(
|
||||||
select(Chapter).where(Chapter.id == chapter_id)
|
select(Chapter).where(Chapter.id == chapter_id)
|
||||||
@@ -1386,6 +1491,9 @@ async def get_chapter_annotations(
|
|||||||
if not chapter:
|
if not chapter:
|
||||||
raise HTTPException(status_code=404, detail="章节不存在")
|
raise HTTPException(status_code=404, detail="章节不存在")
|
||||||
|
|
||||||
|
# 验证项目访问权限
|
||||||
|
await verify_project_access(chapter.project_id, user_id, db)
|
||||||
|
|
||||||
# 获取分析结果
|
# 获取分析结果
|
||||||
analysis_result = await db.execute(
|
analysis_result = await db.execute(
|
||||||
select(PlotAnalysis)
|
select(PlotAnalysis)
|
||||||
@@ -1623,13 +1731,8 @@ async def batch_generate_chapters_in_order(
|
|||||||
if not user_id:
|
if not user_id:
|
||||||
raise HTTPException(status_code=401, detail="未登录")
|
raise HTTPException(status_code=401, detail="未登录")
|
||||||
|
|
||||||
# 验证项目存在
|
# 验证项目存在和用户权限
|
||||||
project_result = await db.execute(
|
project = await verify_project_access(project_id, user_id, db)
|
||||||
select(Project).where(Project.id == project_id)
|
|
||||||
)
|
|
||||||
project = project_result.scalar_one_or_none()
|
|
||||||
if not project:
|
|
||||||
raise HTTPException(status_code=404, detail="项目不存在")
|
|
||||||
|
|
||||||
# 获取项目的所有章节,按序号排序
|
# 获取项目的所有章节,按序号排序
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
@@ -1750,12 +1853,17 @@ async def get_batch_generation_status(
|
|||||||
@router.get("/project/{project_id}/batch-generate/active", summary="获取项目当前运行中的批量生成任务")
|
@router.get("/project/{project_id}/batch-generate/active", summary="获取项目当前运行中的批量生成任务")
|
||||||
async def get_active_batch_generation(
|
async def get_active_batch_generation(
|
||||||
project_id: str,
|
project_id: str,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取项目当前运行中的批量生成任务
|
获取项目当前运行中的批量生成任务
|
||||||
用于页面刷新后恢复任务状态
|
用于页面刷新后恢复任务状态
|
||||||
"""
|
"""
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(project_id, user_id, db)
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(BatchGenerationTask)
|
select(BatchGenerationTask)
|
||||||
.where(BatchGenerationTask.project_id == project_id)
|
.where(BatchGenerationTask.project_id == project_id)
|
||||||
|
|||||||
@@ -24,12 +24,50 @@ router = APIRouter(prefix="/characters", tags=["角色管理"])
|
|||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
|
||||||
|
"""
|
||||||
|
验证用户是否有权访问指定项目
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: 项目ID
|
||||||
|
user_id: 用户ID
|
||||||
|
db: 数据库会话
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Project: 项目对象
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: 401 未登录,404 项目不存在或无权访问
|
||||||
|
"""
|
||||||
|
if not user_id:
|
||||||
|
raise HTTPException(status_code=401, detail="未登录")
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Project).where(
|
||||||
|
Project.id == project_id,
|
||||||
|
Project.user_id == user_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
project = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not project:
|
||||||
|
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
|
||||||
|
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||||
|
|
||||||
|
return project
|
||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=CharacterListResponse, summary="获取角色列表")
|
@router.get("", response_model=CharacterListResponse, summary="获取角色列表")
|
||||||
async def get_characters(
|
async def get_characters(
|
||||||
project_id: str,
|
project_id: str,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""获取指定项目的所有角色(query参数版本)"""
|
"""获取指定项目的所有角色(query参数版本)"""
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(project_id, user_id, db)
|
||||||
|
|
||||||
# 获取总数
|
# 获取总数
|
||||||
count_result = await db.execute(
|
count_result = await db.execute(
|
||||||
select(func.count(Character.id)).where(Character.project_id == project_id)
|
select(func.count(Character.id)).where(Character.project_id == project_id)
|
||||||
@@ -93,9 +131,14 @@ async def get_characters(
|
|||||||
@router.get("/project/{project_id}", response_model=CharacterListResponse, summary="获取项目的所有角色")
|
@router.get("/project/{project_id}", response_model=CharacterListResponse, summary="获取项目的所有角色")
|
||||||
async def get_project_characters(
|
async def get_project_characters(
|
||||||
project_id: str,
|
project_id: str,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""获取指定项目的所有角色(路径参数版本)"""
|
"""获取指定项目的所有角色(路径参数版本)"""
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(project_id, user_id, db)
|
||||||
|
|
||||||
# 获取总数
|
# 获取总数
|
||||||
count_result = await db.execute(
|
count_result = await db.execute(
|
||||||
select(func.count(Character.id)).where(Character.project_id == project_id)
|
select(func.count(Character.id)).where(Character.project_id == project_id)
|
||||||
@@ -159,6 +202,7 @@ async def get_project_characters(
|
|||||||
@router.get("/{character_id}", response_model=CharacterResponse, summary="获取角色详情")
|
@router.get("/{character_id}", response_model=CharacterResponse, summary="获取角色详情")
|
||||||
async def get_character(
|
async def get_character(
|
||||||
character_id: str,
|
character_id: str,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""根据ID获取角色详情"""
|
"""根据ID获取角色详情"""
|
||||||
@@ -170,6 +214,10 @@ async def get_character(
|
|||||||
if not character:
|
if not character:
|
||||||
raise HTTPException(status_code=404, detail="角色不存在")
|
raise HTTPException(status_code=404, detail="角色不存在")
|
||||||
|
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(character.project_id, user_id, db)
|
||||||
|
|
||||||
return character
|
return character
|
||||||
|
|
||||||
|
|
||||||
@@ -177,6 +225,7 @@ async def get_character(
|
|||||||
async def update_character(
|
async def update_character(
|
||||||
character_id: str,
|
character_id: str,
|
||||||
character_update: CharacterUpdate,
|
character_update: CharacterUpdate,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""更新角色信息"""
|
"""更新角色信息"""
|
||||||
@@ -188,6 +237,10 @@ async def update_character(
|
|||||||
if not character:
|
if not character:
|
||||||
raise HTTPException(status_code=404, detail="角色不存在")
|
raise HTTPException(status_code=404, detail="角色不存在")
|
||||||
|
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(character.project_id, user_id, db)
|
||||||
|
|
||||||
# 更新字段
|
# 更新字段
|
||||||
update_data = character_update.model_dump(exclude_unset=True)
|
update_data = character_update.model_dump(exclude_unset=True)
|
||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
@@ -201,6 +254,7 @@ async def update_character(
|
|||||||
@router.delete("/{character_id}", summary="删除角色")
|
@router.delete("/{character_id}", summary="删除角色")
|
||||||
async def delete_character(
|
async def delete_character(
|
||||||
character_id: str,
|
character_id: str,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""删除角色"""
|
"""删除角色"""
|
||||||
@@ -212,6 +266,10 @@ async def delete_character(
|
|||||||
if not character:
|
if not character:
|
||||||
raise HTTPException(status_code=404, detail="角色不存在")
|
raise HTTPException(status_code=404, detail="角色不存在")
|
||||||
|
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(character.project_id, user_id, db)
|
||||||
|
|
||||||
await db.delete(character)
|
await db.delete(character)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
@@ -233,13 +291,9 @@ async def generate_character(
|
|||||||
|
|
||||||
生成内容包括:姓名、年龄、性别、性格、外貌、背景故事、人际关系等
|
生成内容包括:姓名、年龄、性别、性格、外貌、背景故事、人际关系等
|
||||||
"""
|
"""
|
||||||
# 验证项目是否存在并获取项目信息
|
# 验证用户权限和项目是否存在
|
||||||
result = await db.execute(
|
user_id = getattr(http_request.state, 'user_id', None)
|
||||||
select(Project).where(Project.id == request.project_id)
|
project = await verify_project_access(request.project_id, user_id, db)
|
||||||
)
|
|
||||||
project = result.scalar_one_or_none()
|
|
||||||
if not project:
|
|
||||||
raise HTTPException(status_code=404, detail="项目不存在")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 获取已存在的角色列表,用于关系网络
|
# 获取已存在的角色列表,用于关系网络
|
||||||
@@ -295,9 +349,6 @@ async def generate_character(
|
|||||||
user_input=user_input
|
user_input=user_input
|
||||||
)
|
)
|
||||||
|
|
||||||
# 获取user_id用于MCP工具调用
|
|
||||||
user_id = http_request.state.user_id if hasattr(http_request.state, 'user_id') else 'default_user'
|
|
||||||
|
|
||||||
# 调用AI生成角色(支持MCP工具)
|
# 调用AI生成角色(支持MCP工具)
|
||||||
logger.info(f"🎯 开始为项目 {request.project_id} 生成角色(启用MCP)")
|
logger.info(f"🎯 开始为项目 {request.project_id} 生成角色(启用MCP)")
|
||||||
logger.info(f" - 角色名:{request.name or 'AI生成'}")
|
logger.info(f" - 角色名:{request.name or 'AI生成'}")
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from typing import List, Optional
|
|||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.models.memory import StoryMemory, PlotAnalysis
|
from app.models.memory import StoryMemory, PlotAnalysis
|
||||||
from app.models.chapter import Chapter
|
from app.models.chapter import Chapter
|
||||||
|
from app.models.project import Project
|
||||||
from app.services.memory_service import memory_service
|
from app.services.memory_service import memory_service
|
||||||
from app.services.plot_analyzer import get_plot_analyzer
|
from app.services.plot_analyzer import get_plot_analyzer
|
||||||
from app.services.ai_service import create_user_ai_service
|
from app.services.ai_service import create_user_ai_service
|
||||||
@@ -17,6 +18,26 @@ logger = get_logger(__name__)
|
|||||||
router = APIRouter(prefix="/api/memories", tags=["memories"])
|
router = APIRouter(prefix="/api/memories", tags=["memories"])
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
|
||||||
|
"""验证用户是否有权访问指定项目"""
|
||||||
|
if not user_id:
|
||||||
|
raise HTTPException(status_code=401, detail="未登录")
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Project).where(
|
||||||
|
Project.id == project_id,
|
||||||
|
Project.user_id == user_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
project = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not project:
|
||||||
|
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
|
||||||
|
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||||
|
|
||||||
|
return project
|
||||||
|
|
||||||
|
|
||||||
@router.post("/projects/{project_id}/analyze-chapter/{chapter_id}")
|
@router.post("/projects/{project_id}/analyze-chapter/{chapter_id}")
|
||||||
async def analyze_chapter(
|
async def analyze_chapter(
|
||||||
project_id: str,
|
project_id: str,
|
||||||
@@ -30,7 +51,10 @@ async def analyze_chapter(
|
|||||||
对指定章节进行剧情分析,提取钩子、伏笔、情节点等,并存入记忆系统
|
对指定章节进行剧情分析,提取钩子、伏笔、情节点等,并存入记忆系统
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
user_id = request.state.user_id
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
|
||||||
|
# 验证用户权限
|
||||||
|
await verify_project_access(project_id, user_id, db)
|
||||||
|
|
||||||
# 获取章节内容
|
# 获取章节内容
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
@@ -192,7 +216,10 @@ async def get_project_memories(
|
|||||||
):
|
):
|
||||||
"""获取项目的记忆列表"""
|
"""获取项目的记忆列表"""
|
||||||
try:
|
try:
|
||||||
user_id = request.state.user_id
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
|
||||||
|
# 验证用户权限
|
||||||
|
await verify_project_access(project_id, user_id, db)
|
||||||
|
|
||||||
# 构建查询
|
# 构建查询
|
||||||
query = select(StoryMemory).where(StoryMemory.project_id == project_id)
|
query = select(StoryMemory).where(StoryMemory.project_id == project_id)
|
||||||
@@ -222,10 +249,16 @@ async def get_project_memories(
|
|||||||
async def get_chapter_analysis(
|
async def get_chapter_analysis(
|
||||||
project_id: str,
|
project_id: str,
|
||||||
chapter_id: str,
|
chapter_id: str,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""获取章节的剧情分析"""
|
"""获取章节的剧情分析"""
|
||||||
try:
|
try:
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
|
||||||
|
# 验证用户权限
|
||||||
|
await verify_project_access(project_id, user_id, db)
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(PlotAnalysis).where(
|
select(PlotAnalysis).where(
|
||||||
and_(
|
and_(
|
||||||
@@ -258,11 +291,15 @@ async def search_memories(
|
|||||||
query: str,
|
query: str,
|
||||||
memory_types: Optional[List[str]] = None,
|
memory_types: Optional[List[str]] = None,
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
min_importance: float = 0.0
|
min_importance: float = 0.0,
|
||||||
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""语义搜索项目记忆"""
|
"""语义搜索项目记忆"""
|
||||||
try:
|
try:
|
||||||
user_id = request.state.user_id
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
|
||||||
|
# 验证用户权限
|
||||||
|
await verify_project_access(project_id, user_id, db)
|
||||||
|
|
||||||
memories = await memory_service.search_memories(
|
memories = await memory_service.search_memories(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@@ -294,7 +331,10 @@ async def get_unresolved_foreshadows(
|
|||||||
):
|
):
|
||||||
"""获取未完结的伏笔"""
|
"""获取未完结的伏笔"""
|
||||||
try:
|
try:
|
||||||
user_id = request.state.user_id
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
|
||||||
|
# 验证用户权限
|
||||||
|
await verify_project_access(project_id, user_id, db)
|
||||||
|
|
||||||
# 从向量库搜索
|
# 从向量库搜索
|
||||||
foreshadows = await memory_service.find_unresolved_foreshadows(
|
foreshadows = await memory_service.find_unresolved_foreshadows(
|
||||||
@@ -317,11 +357,15 @@ async def get_unresolved_foreshadows(
|
|||||||
@router.get("/projects/{project_id}/stats")
|
@router.get("/projects/{project_id}/stats")
|
||||||
async def get_memory_stats(
|
async def get_memory_stats(
|
||||||
project_id: str,
|
project_id: str,
|
||||||
request: Request
|
request: Request,
|
||||||
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""获取记忆统计信息"""
|
"""获取记忆统计信息"""
|
||||||
try:
|
try:
|
||||||
user_id = request.state.user_id
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
|
||||||
|
# 验证用户权限
|
||||||
|
await verify_project_access(project_id, user_id, db)
|
||||||
|
|
||||||
stats = await memory_service.get_memory_stats(
|
stats = await memory_service.get_memory_stats(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@@ -347,7 +391,10 @@ async def delete_chapter_memories(
|
|||||||
):
|
):
|
||||||
"""删除章节的所有记忆"""
|
"""删除章节的所有记忆"""
|
||||||
try:
|
try:
|
||||||
user_id = request.state.user_id
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
|
||||||
|
# 验证用户权限
|
||||||
|
await verify_project_access(project_id, user_id, db)
|
||||||
|
|
||||||
# 从数据库删除
|
# 从数据库删除
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
"""组织管理API"""
|
"""组织管理API"""
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy import select, and_
|
from sqlalchemy import select, and_
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
@@ -31,6 +31,26 @@ router = APIRouter(prefix="/organizations", tags=["组织管理"])
|
|||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
|
||||||
|
"""验证用户是否有权访问指定项目"""
|
||||||
|
if not user_id:
|
||||||
|
raise HTTPException(status_code=401, detail="未登录")
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Project).where(
|
||||||
|
Project.id == project_id,
|
||||||
|
Project.user_id == user_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
project = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not project:
|
||||||
|
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
|
||||||
|
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||||
|
|
||||||
|
return project
|
||||||
|
|
||||||
|
|
||||||
class OrganizationGenerateRequest(BaseModel):
|
class OrganizationGenerateRequest(BaseModel):
|
||||||
"""AI生成组织的请求模型"""
|
"""AI生成组织的请求模型"""
|
||||||
project_id: str = Field(..., description="项目ID")
|
project_id: str = Field(..., description="项目ID")
|
||||||
@@ -44,8 +64,13 @@ class OrganizationGenerateRequest(BaseModel):
|
|||||||
@router.get("/project/{project_id}", response_model=List[OrganizationDetailResponse], summary="获取项目的所有组织")
|
@router.get("/project/{project_id}", response_model=List[OrganizationDetailResponse], summary="获取项目的所有组织")
|
||||||
async def get_project_organizations(
|
async def get_project_organizations(
|
||||||
project_id: str,
|
project_id: str,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(project_id, user_id, db)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
获取项目中的所有组织及其详情
|
获取项目中的所有组织及其详情
|
||||||
|
|
||||||
@@ -85,6 +110,7 @@ async def get_project_organizations(
|
|||||||
@router.get("/{org_id}", response_model=OrganizationResponse, summary="获取组织详情")
|
@router.get("/{org_id}", response_model=OrganizationResponse, summary="获取组织详情")
|
||||||
async def get_organization(
|
async def get_organization(
|
||||||
org_id: str,
|
org_id: str,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""获取组织的详细信息"""
|
"""获取组织的详细信息"""
|
||||||
@@ -96,12 +122,17 @@ async def get_organization(
|
|||||||
if not org:
|
if not org:
|
||||||
raise HTTPException(status_code=404, detail="组织不存在")
|
raise HTTPException(status_code=404, detail="组织不存在")
|
||||||
|
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(org.project_id, user_id, db)
|
||||||
|
|
||||||
return org
|
return org
|
||||||
|
|
||||||
|
|
||||||
@router.post("/", response_model=OrganizationResponse, summary="创建组织")
|
@router.post("/", response_model=OrganizationResponse, summary="创建组织")
|
||||||
async def create_organization(
|
async def create_organization(
|
||||||
organization: OrganizationCreate,
|
organization: OrganizationCreate,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -110,6 +141,10 @@ async def create_organization(
|
|||||||
- 需要关联到一个已存在的角色记录(is_organization=True)
|
- 需要关联到一个已存在的角色记录(is_organization=True)
|
||||||
- 可以设置父组织、势力等级等属性
|
- 可以设置父组织、势力等级等属性
|
||||||
"""
|
"""
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(organization.project_id, user_id, db)
|
||||||
|
|
||||||
# 验证角色是否存在且是组织
|
# 验证角色是否存在且是组织
|
||||||
char_result = await db.execute(
|
char_result = await db.execute(
|
||||||
select(Character).where(Character.id == organization.character_id)
|
select(Character).where(Character.id == organization.character_id)
|
||||||
@@ -142,6 +177,7 @@ async def create_organization(
|
|||||||
async def update_organization(
|
async def update_organization(
|
||||||
org_id: str,
|
org_id: str,
|
||||||
organization: OrganizationUpdate,
|
organization: OrganizationUpdate,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""更新组织的属性"""
|
"""更新组织的属性"""
|
||||||
@@ -153,6 +189,10 @@ async def update_organization(
|
|||||||
if not db_org:
|
if not db_org:
|
||||||
raise HTTPException(status_code=404, detail="组织不存在")
|
raise HTTPException(status_code=404, detail="组织不存在")
|
||||||
|
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(db_org.project_id, user_id, db)
|
||||||
|
|
||||||
# 更新字段
|
# 更新字段
|
||||||
update_data = organization.model_dump(exclude_unset=True)
|
update_data = organization.model_dump(exclude_unset=True)
|
||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
@@ -168,6 +208,7 @@ async def update_organization(
|
|||||||
@router.delete("/{org_id}", summary="删除组织")
|
@router.delete("/{org_id}", summary="删除组织")
|
||||||
async def delete_organization(
|
async def delete_organization(
|
||||||
org_id: str,
|
org_id: str,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""删除组织(会级联删除所有成员关系)"""
|
"""删除组织(会级联删除所有成员关系)"""
|
||||||
@@ -179,6 +220,10 @@ async def delete_organization(
|
|||||||
if not db_org:
|
if not db_org:
|
||||||
raise HTTPException(status_code=404, detail="组织不存在")
|
raise HTTPException(status_code=404, detail="组织不存在")
|
||||||
|
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(db_org.project_id, user_id, db)
|
||||||
|
|
||||||
await db.delete(db_org)
|
await db.delete(db_org)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
@@ -191,6 +236,7 @@ async def delete_organization(
|
|||||||
@router.get("/{org_id}/members", response_model=List[OrganizationMemberDetailResponse], summary="获取组织成员")
|
@router.get("/{org_id}/members", response_model=List[OrganizationMemberDetailResponse], summary="获取组织成员")
|
||||||
async def get_organization_members(
|
async def get_organization_members(
|
||||||
org_id: str,
|
org_id: str,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -202,9 +248,14 @@ async def get_organization_members(
|
|||||||
org_result = await db.execute(
|
org_result = await db.execute(
|
||||||
select(Organization).where(Organization.id == org_id)
|
select(Organization).where(Organization.id == org_id)
|
||||||
)
|
)
|
||||||
if not org_result.scalar_one_or_none():
|
org = org_result.scalar_one_or_none()
|
||||||
|
if not org:
|
||||||
raise HTTPException(status_code=404, detail="组织不存在")
|
raise HTTPException(status_code=404, detail="组织不存在")
|
||||||
|
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(org.project_id, user_id, db)
|
||||||
|
|
||||||
# 获取成员列表
|
# 获取成员列表
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(OrganizationMember)
|
select(OrganizationMember)
|
||||||
@@ -244,6 +295,7 @@ async def get_organization_members(
|
|||||||
async def add_organization_member(
|
async def add_organization_member(
|
||||||
org_id: str,
|
org_id: str,
|
||||||
member: OrganizationMemberCreate,
|
member: OrganizationMemberCreate,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -260,6 +312,10 @@ async def add_organization_member(
|
|||||||
if not org:
|
if not org:
|
||||||
raise HTTPException(status_code=404, detail="组织不存在")
|
raise HTTPException(status_code=404, detail="组织不存在")
|
||||||
|
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(org.project_id, user_id, db)
|
||||||
|
|
||||||
# 验证角色存在
|
# 验证角色存在
|
||||||
char_result = await db.execute(
|
char_result = await db.execute(
|
||||||
select(Character).where(Character.id == member.character_id)
|
select(Character).where(Character.id == member.character_id)
|
||||||
@@ -304,6 +360,7 @@ async def add_organization_member(
|
|||||||
async def update_organization_member(
|
async def update_organization_member(
|
||||||
member_id: str,
|
member_id: str,
|
||||||
member: OrganizationMemberUpdate,
|
member: OrganizationMemberUpdate,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""更新组织成员的职位、忠诚度等信息"""
|
"""更新组织成员的职位、忠诚度等信息"""
|
||||||
@@ -315,6 +372,14 @@ async def update_organization_member(
|
|||||||
if not db_member:
|
if not db_member:
|
||||||
raise HTTPException(status_code=404, detail="成员记录不存在")
|
raise HTTPException(status_code=404, detail="成员记录不存在")
|
||||||
|
|
||||||
|
# 通过成员所属的组织验证用户权限
|
||||||
|
org_result = await db.execute(
|
||||||
|
select(Organization).where(Organization.id == db_member.organization_id)
|
||||||
|
)
|
||||||
|
org = org_result.scalar_one()
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(org.project_id, user_id, db)
|
||||||
|
|
||||||
# 更新字段
|
# 更新字段
|
||||||
update_data = member.model_dump(exclude_unset=True)
|
update_data = member.model_dump(exclude_unset=True)
|
||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
@@ -330,6 +395,7 @@ async def update_organization_member(
|
|||||||
@router.delete("/members/{member_id}", summary="移除组织成员")
|
@router.delete("/members/{member_id}", summary="移除组织成员")
|
||||||
async def remove_organization_member(
|
async def remove_organization_member(
|
||||||
member_id: str,
|
member_id: str,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -350,6 +416,10 @@ async def remove_organization_member(
|
|||||||
select(Organization).where(Organization.id == db_member.organization_id)
|
select(Organization).where(Organization.id == db_member.organization_id)
|
||||||
)
|
)
|
||||||
org = org_result.scalar_one()
|
org = org_result.scalar_one()
|
||||||
|
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(org.project_id, user_id, db)
|
||||||
org.member_count = max(0, org.member_count - 1)
|
org.member_count = max(0, org.member_count - 1)
|
||||||
|
|
||||||
await db.delete(db_member)
|
await db.delete(db_member)
|
||||||
@@ -360,7 +430,8 @@ async def remove_organization_member(
|
|||||||
|
|
||||||
@router.post("/generate", response_model=CharacterResponse, summary="AI生成组织")
|
@router.post("/generate", response_model=CharacterResponse, summary="AI生成组织")
|
||||||
async def generate_organization(
|
async def generate_organization(
|
||||||
request: OrganizationGenerateRequest,
|
gen_request: OrganizationGenerateRequest,
|
||||||
|
http_request: Request,
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||||
):
|
):
|
||||||
@@ -372,19 +443,15 @@ async def generate_organization(
|
|||||||
|
|
||||||
生成内容包括:组织名称、类型、特性、背景、目的、势力等级等
|
生成内容包括:组织名称、类型、特性、背景、目的、势力等级等
|
||||||
"""
|
"""
|
||||||
# 验证项目是否存在并获取项目信息
|
# 验证用户权限
|
||||||
result = await db.execute(
|
user_id = getattr(http_request.state, 'user_id', None)
|
||||||
select(Project).where(Project.id == request.project_id)
|
project = await verify_project_access(gen_request.project_id, user_id, db)
|
||||||
)
|
|
||||||
project = result.scalar_one_or_none()
|
|
||||||
if not project:
|
|
||||||
raise HTTPException(status_code=404, detail="项目不存在")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 获取已存在的角色和组织列表
|
# 获取已存在的角色和组织列表
|
||||||
existing_chars_result = await db.execute(
|
existing_chars_result = await db.execute(
|
||||||
select(Character)
|
select(Character)
|
||||||
.where(Character.project_id == request.project_id)
|
.where(Character.project_id == gen_request.project_id)
|
||||||
.order_by(Character.created_at.desc())
|
.order_by(Character.created_at.desc())
|
||||||
)
|
)
|
||||||
existing_characters = existing_chars_result.scalars().all()
|
existing_characters = existing_chars_result.scalars().all()
|
||||||
@@ -422,10 +489,10 @@ async def generate_organization(
|
|||||||
# 构建用户输入信息
|
# 构建用户输入信息
|
||||||
user_input = f"""
|
user_input = f"""
|
||||||
用户要求:
|
用户要求:
|
||||||
- 组织名称:{request.name or '请AI生成'}
|
- 组织名称:{gen_request.name or '请AI生成'}
|
||||||
- 组织类型:{request.organization_type or '请AI根据世界观决定'}
|
- 组织类型:{gen_request.organization_type or '请AI根据世界观决定'}
|
||||||
- 背景设定:{request.background or '无特殊要求'}
|
- 背景设定:{gen_request.background or '无特殊要求'}
|
||||||
- 其他要求:{request.requirements or '无'}
|
- 其他要求:{gen_request.requirements or '无'}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 使用统一的提示词服务
|
# 使用统一的提示词服务
|
||||||
@@ -435,10 +502,10 @@ async def generate_organization(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 调用AI生成组织
|
# 调用AI生成组织
|
||||||
logger.info(f"🎯 开始为项目 {request.project_id} 生成组织")
|
logger.info(f"🎯 开始为项目 {gen_request.project_id} 生成组织")
|
||||||
logger.info(f" - 组织名:{request.name or 'AI生成'}")
|
logger.info(f" - 组织名:{gen_request.name or 'AI生成'}")
|
||||||
logger.info(f" - 组织类型:{request.organization_type or 'AI决定'}")
|
logger.info(f" - 组织类型:{gen_request.organization_type or 'AI决定'}")
|
||||||
logger.info(f" - 背景设定:{request.background or '无'}")
|
logger.info(f" - 背景设定:{gen_request.background or '无'}")
|
||||||
logger.info(f" - AI提供商:{user_ai_service.api_provider}")
|
logger.info(f" - AI提供商:{user_ai_service.api_provider}")
|
||||||
logger.info(f" - AI模型:{user_ai_service.default_model}")
|
logger.info(f" - AI模型:{user_ai_service.default_model}")
|
||||||
logger.info(f" - Prompt长度:{len(prompt)} 字符")
|
logger.info(f" - Prompt长度:{len(prompt)} 字符")
|
||||||
@@ -492,8 +559,8 @@ async def generate_organization(
|
|||||||
|
|
||||||
# 创建角色记录(组织也是角色的一种)
|
# 创建角色记录(组织也是角色的一种)
|
||||||
character = Character(
|
character = Character(
|
||||||
project_id=request.project_id,
|
project_id=gen_request.project_id,
|
||||||
name=organization_data.get("name", request.name or "未命名组织"),
|
name=organization_data.get("name", gen_request.name or "未命名组织"),
|
||||||
is_organization=True,
|
is_organization=True,
|
||||||
role_type="supporting", # 组织通常作为配角
|
role_type="supporting", # 组织通常作为配角
|
||||||
personality=organization_data.get("personality", ""),
|
personality=organization_data.get("personality", ""),
|
||||||
@@ -518,7 +585,7 @@ async def generate_organization(
|
|||||||
# 自动创建Organization详情记录
|
# 自动创建Organization详情记录
|
||||||
organization = Organization(
|
organization = Organization(
|
||||||
character_id=character.id,
|
character_id=character.id,
|
||||||
project_id=request.project_id,
|
project_id=gen_request.project_id,
|
||||||
member_count=0,
|
member_count=0,
|
||||||
power_level=organization_data.get("power_level", 50),
|
power_level=organization_data.get("power_level", 50),
|
||||||
location=organization_data.get("location"),
|
location=organization_data.get("location"),
|
||||||
@@ -532,7 +599,7 @@ async def generate_organization(
|
|||||||
|
|
||||||
# 记录生成历史
|
# 记录生成历史
|
||||||
history = GenerationHistory(
|
history = GenerationHistory(
|
||||||
project_id=request.project_id,
|
project_id=gen_request.project_id,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
generated_content=ai_content,
|
generated_content=ai_content,
|
||||||
model=user_ai_service.default_model
|
model=user_ai_service.default_model
|
||||||
@@ -542,7 +609,7 @@ async def generate_organization(
|
|||||||
await db.commit()
|
await db.commit()
|
||||||
await db.refresh(character)
|
await db.refresh(character)
|
||||||
|
|
||||||
logger.info(f"🎉 成功为项目 {request.project_id} 生成组织: {character.name}")
|
logger.info(f"🎉 成功为项目 {gen_request.project_id} 生成组织: {character.name}")
|
||||||
|
|
||||||
return character
|
return character
|
||||||
|
|
||||||
|
|||||||
+81
-23
@@ -30,19 +30,49 @@ router = APIRouter(prefix="/outlines", tags=["大纲管理"])
|
|||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
|
||||||
|
"""
|
||||||
|
验证用户是否有权访问指定项目
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: 项目ID
|
||||||
|
user_id: 用户ID
|
||||||
|
db: 数据库会话
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Project: 项目对象
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: 401 未登录,404 项目不存在或无权访问
|
||||||
|
"""
|
||||||
|
if not user_id:
|
||||||
|
raise HTTPException(status_code=401, detail="未登录")
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Project).where(
|
||||||
|
Project.id == project_id,
|
||||||
|
Project.user_id == user_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
project = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not project:
|
||||||
|
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
|
||||||
|
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||||
|
|
||||||
|
return project
|
||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=OutlineResponse, summary="创建大纲")
|
@router.post("", response_model=OutlineResponse, summary="创建大纲")
|
||||||
async def create_outline(
|
async def create_outline(
|
||||||
outline: OutlineCreate,
|
outline: OutlineCreate,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""创建新的章节大纲,同时创建对应的章节记录"""
|
"""创建新的章节大纲,同时创建对应的章节记录"""
|
||||||
# 验证项目是否存在
|
# 验证用户权限
|
||||||
result = await db.execute(
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
select(Project).where(Project.id == outline.project_id)
|
await verify_project_access(outline.project_id, user_id, db)
|
||||||
)
|
|
||||||
project = result.scalar_one_or_none()
|
|
||||||
if not project:
|
|
||||||
raise HTTPException(status_code=404, detail="项目不存在")
|
|
||||||
|
|
||||||
# 创建大纲
|
# 创建大纲
|
||||||
db_outline = Outline(**outline.model_dump())
|
db_outline = Outline(**outline.model_dump())
|
||||||
@@ -66,9 +96,14 @@ async def create_outline(
|
|||||||
@router.get("", response_model=OutlineListResponse, summary="获取大纲列表")
|
@router.get("", response_model=OutlineListResponse, summary="获取大纲列表")
|
||||||
async def get_outlines(
|
async def get_outlines(
|
||||||
project_id: str,
|
project_id: str,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""获取指定项目的所有大纲"""
|
"""获取指定项目的所有大纲"""
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(project_id, user_id, db)
|
||||||
|
|
||||||
# 获取总数
|
# 获取总数
|
||||||
count_result = await db.execute(
|
count_result = await db.execute(
|
||||||
select(func.count(Outline.id)).where(Outline.project_id == project_id)
|
select(func.count(Outline.id)).where(Outline.project_id == project_id)
|
||||||
@@ -89,9 +124,14 @@ async def get_outlines(
|
|||||||
@router.get("/project/{project_id}", response_model=OutlineListResponse, summary="获取项目的所有大纲")
|
@router.get("/project/{project_id}", response_model=OutlineListResponse, summary="获取项目的所有大纲")
|
||||||
async def get_project_outlines(
|
async def get_project_outlines(
|
||||||
project_id: str,
|
project_id: str,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""获取指定项目的所有大纲(路径参数版本)"""
|
"""获取指定项目的所有大纲(路径参数版本)"""
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(project_id, user_id, db)
|
||||||
|
|
||||||
# 获取总数
|
# 获取总数
|
||||||
count_result = await db.execute(
|
count_result = await db.execute(
|
||||||
select(func.count(Outline.id)).where(Outline.project_id == project_id)
|
select(func.count(Outline.id)).where(Outline.project_id == project_id)
|
||||||
@@ -112,6 +152,7 @@ async def get_project_outlines(
|
|||||||
@router.get("/{outline_id}", response_model=OutlineResponse, summary="获取大纲详情")
|
@router.get("/{outline_id}", response_model=OutlineResponse, summary="获取大纲详情")
|
||||||
async def get_outline(
|
async def get_outline(
|
||||||
outline_id: str,
|
outline_id: str,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""根据ID获取大纲详情"""
|
"""根据ID获取大纲详情"""
|
||||||
@@ -123,6 +164,10 @@ async def get_outline(
|
|||||||
if not outline:
|
if not outline:
|
||||||
raise HTTPException(status_code=404, detail="大纲不存在")
|
raise HTTPException(status_code=404, detail="大纲不存在")
|
||||||
|
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(outline.project_id, user_id, db)
|
||||||
|
|
||||||
return outline
|
return outline
|
||||||
|
|
||||||
|
|
||||||
@@ -130,6 +175,7 @@ async def get_outline(
|
|||||||
async def update_outline(
|
async def update_outline(
|
||||||
outline_id: str,
|
outline_id: str,
|
||||||
outline_update: OutlineUpdate,
|
outline_update: OutlineUpdate,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""更新大纲信息,同步更新对应章节和structure字段"""
|
"""更新大纲信息,同步更新对应章节和structure字段"""
|
||||||
@@ -141,6 +187,10 @@ async def update_outline(
|
|||||||
if not outline:
|
if not outline:
|
||||||
raise HTTPException(status_code=404, detail="大纲不存在")
|
raise HTTPException(status_code=404, detail="大纲不存在")
|
||||||
|
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(outline.project_id, user_id, db)
|
||||||
|
|
||||||
# 更新字段
|
# 更新字段
|
||||||
update_data = outline_update.model_dump(exclude_unset=True)
|
update_data = outline_update.model_dump(exclude_unset=True)
|
||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
@@ -196,6 +246,7 @@ async def update_outline(
|
|||||||
@router.delete("/{outline_id}", summary="删除大纲")
|
@router.delete("/{outline_id}", summary="删除大纲")
|
||||||
async def delete_outline(
|
async def delete_outline(
|
||||||
outline_id: str,
|
outline_id: str,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""删除大纲,同步删除章节,并重新排序后续项"""
|
"""删除大纲,同步删除章节,并重新排序后续项"""
|
||||||
@@ -207,6 +258,10 @@ async def delete_outline(
|
|||||||
if not outline:
|
if not outline:
|
||||||
raise HTTPException(status_code=404, detail="大纲不存在")
|
raise HTTPException(status_code=404, detail="大纲不存在")
|
||||||
|
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(outline.project_id, user_id, db)
|
||||||
|
|
||||||
project_id = outline.project_id
|
project_id = outline.project_id
|
||||||
deleted_order = outline.order_index
|
deleted_order = outline.order_index
|
||||||
|
|
||||||
@@ -252,7 +307,8 @@ async def delete_outline(
|
|||||||
|
|
||||||
@router.post("/reorder", summary="批量重排序大纲")
|
@router.post("/reorder", summary="批量重排序大纲")
|
||||||
async def reorder_outlines(
|
async def reorder_outlines(
|
||||||
request: OutlineReorderRequest,
|
reorder_request: OutlineReorderRequest,
|
||||||
|
http_request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -261,10 +317,20 @@ async def reorder_outlines(
|
|||||||
策略:先收集所有变更,最后一次性提交,避免临时冲突
|
策略:先收集所有变更,最后一次性提交,避免临时冲突
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# 验证用户权限(通过第一个大纲的project_id)
|
||||||
|
user_id = getattr(http_request.state, 'user_id', None)
|
||||||
|
if reorder_request.orders and len(reorder_request.orders) > 0:
|
||||||
|
first_outline_result = await db.execute(
|
||||||
|
select(Outline).where(Outline.id == reorder_request.orders[0].id)
|
||||||
|
)
|
||||||
|
first_outline = first_outline_result.scalar_one_or_none()
|
||||||
|
if first_outline:
|
||||||
|
await verify_project_access(first_outline.project_id, user_id, db)
|
||||||
|
|
||||||
# 第一步:收集所有大纲和对应的章节
|
# 第一步:收集所有大纲和对应的章节
|
||||||
outline_chapter_map = {} # {outline_id: (outline, chapter, old_order, new_order)}
|
outline_chapter_map = {} # {outline_id: (outline, chapter, old_order, new_order)}
|
||||||
|
|
||||||
for item in request.orders:
|
for item in reorder_request.orders:
|
||||||
outline_id = item.id
|
outline_id = item.id
|
||||||
new_order = item.order_index
|
new_order = item.order_index
|
||||||
|
|
||||||
@@ -341,13 +407,9 @@ async def generate_outline(
|
|||||||
- new: 强制全新生成
|
- new: 强制全新生成
|
||||||
- continue: 强制续写模式
|
- continue: 强制续写模式
|
||||||
"""
|
"""
|
||||||
# 验证项目是否存在
|
# 验证用户权限
|
||||||
result = await db.execute(
|
user_id = getattr(http_request.state, 'user_id', None)
|
||||||
select(Project).where(Project.id == request.project_id)
|
project = await verify_project_access(request.project_id, user_id, db)
|
||||||
)
|
|
||||||
project = result.scalar_one_or_none()
|
|
||||||
if not project:
|
|
||||||
raise HTTPException(status_code=404, detail="项目不存在")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 获取现有大纲(强制从数据库获取最新数据,包括用户手动修改的内容)
|
# 获取现有大纲(强制从数据库获取最新数据,包括用户手动修改的内容)
|
||||||
@@ -1472,13 +1534,9 @@ async def generate_outline_stream(
|
|||||||
"model": "gpt-4" // 可选
|
"model": "gpt-4" // 可选
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
# 验证项目是否存在
|
# 验证用户权限
|
||||||
result = await db.execute(
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
select(Project).where(Project.id == data.get("project_id"))
|
project = await verify_project_access(data.get("project_id"), user_id, db)
|
||||||
)
|
|
||||||
project = result.scalar_one_or_none()
|
|
||||||
if not project:
|
|
||||||
raise HTTPException(status_code=404, detail="项目不存在")
|
|
||||||
|
|
||||||
# 判断模式
|
# 判断模式
|
||||||
mode = data.get("mode", "auto")
|
mode = data.get("mode", "auto")
|
||||||
|
|||||||
+155
-41
@@ -41,17 +41,31 @@ router = APIRouter(prefix="/projects", tags=["项目管理"])
|
|||||||
@router.post("", response_model=ProjectResponse, summary="创建项目")
|
@router.post("", response_model=ProjectResponse, summary="创建项目")
|
||||||
async def create_project(
|
async def create_project(
|
||||||
project: ProjectCreate,
|
project: ProjectCreate,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db),
|
||||||
|
request: Request = None
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
logger.info(f"创建新项目: {project.title}")
|
# 从认证中间件获取用户ID
|
||||||
db_project = Project(**project.model_dump())
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
if not user_id:
|
||||||
|
logger.warning("未登录用户尝试创建项目")
|
||||||
|
raise HTTPException(status_code=401, detail="未登录")
|
||||||
|
|
||||||
|
logger.info(f"创建新项目: {project.title}, user_id={user_id}")
|
||||||
|
|
||||||
|
# 创建项目时自动设置user_id
|
||||||
|
project_data = project.model_dump()
|
||||||
|
project_data['user_id'] = user_id
|
||||||
|
db_project = Project(**project_data)
|
||||||
|
|
||||||
db.add(db_project)
|
db.add(db_project)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
await db.refresh(db_project)
|
await db.refresh(db_project)
|
||||||
logger.info(f"项目创建成功: {db_project.id}")
|
logger.info(f"项目创建成功: project_id={db_project.id}, user_id={user_id}")
|
||||||
|
|
||||||
return db_project
|
return db_project
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"创建项目失败: {str(e)}", exc_info=True)
|
logger.error(f"创建项目失败: {str(e)}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
@@ -61,24 +75,38 @@ async def create_project(
|
|||||||
async def get_projects(
|
async def get_projects(
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db),
|
||||||
|
request: Request = None
|
||||||
):
|
):
|
||||||
"""获取所有项目列表"""
|
"""获取当前用户的项目列表"""
|
||||||
try:
|
try:
|
||||||
logger.debug(f"获取项目列表: skip={skip}, limit={limit}")
|
# 从认证中间件获取用户ID
|
||||||
count_result = await db.execute(select(func.count(Project.id)))
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
if not user_id:
|
||||||
|
logger.warning("未登录用户尝试获取项目列表")
|
||||||
|
raise HTTPException(status_code=401, detail="未登录")
|
||||||
|
|
||||||
|
logger.debug(f"获取项目列表: user_id={user_id}, skip={skip}, limit={limit}")
|
||||||
|
|
||||||
|
# 只查询当前用户的项目
|
||||||
|
count_result = await db.execute(
|
||||||
|
select(func.count(Project.id)).where(Project.user_id == user_id)
|
||||||
|
)
|
||||||
total = count_result.scalar_one()
|
total = count_result.scalar_one()
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Project)
|
select(Project)
|
||||||
|
.where(Project.user_id == user_id)
|
||||||
.order_by(Project.updated_at.desc())
|
.order_by(Project.updated_at.desc())
|
||||||
.offset(skip)
|
.offset(skip)
|
||||||
.limit(limit)
|
.limit(limit)
|
||||||
)
|
)
|
||||||
projects = result.scalars().all()
|
projects = result.scalars().all()
|
||||||
logger.info(f"获取项目列表成功: 共{total}个项目")
|
logger.info(f"获取项目列表成功: user_id={user_id}, 共{total}个项目")
|
||||||
|
|
||||||
return ProjectListResponse(total=total, items=projects)
|
return ProjectListResponse(total=total, items=projects)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取项目列表失败: {str(e)}", exc_info=True)
|
logger.error(f"获取项目列表失败: {str(e)}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
@@ -87,17 +115,29 @@ async def get_projects(
|
|||||||
@router.get("/{project_id}", response_model=ProjectResponse, summary="获取项目详情")
|
@router.get("/{project_id}", response_model=ProjectResponse, summary="获取项目详情")
|
||||||
async def get_project(
|
async def get_project(
|
||||||
project_id: str,
|
project_id: str,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db),
|
||||||
|
request: Request = None
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
logger.debug(f"获取项目详情: {project_id}")
|
# 从认证中间件获取用户ID
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
if not user_id:
|
||||||
|
logger.warning("未登录用户尝试获取项目详情")
|
||||||
|
raise HTTPException(status_code=401, detail="未登录")
|
||||||
|
|
||||||
|
logger.debug(f"获取项目详情: project_id={project_id}, user_id={user_id}")
|
||||||
|
|
||||||
|
# 只查询当前用户的项目
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Project).where(Project.id == project_id)
|
select(Project).where(
|
||||||
|
Project.id == project_id,
|
||||||
|
Project.user_id == user_id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
project = result.scalar_one_or_none()
|
project = result.scalar_one_or_none()
|
||||||
|
|
||||||
if not project:
|
if not project:
|
||||||
logger.warning(f"项目不存在: {project_id}")
|
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
|
||||||
raise HTTPException(status_code=404, detail="项目不存在")
|
raise HTTPException(status_code=404, detail="项目不存在")
|
||||||
|
|
||||||
logger.info(f"获取项目详情成功: {project.title}")
|
logger.info(f"获取项目详情成功: {project.title}")
|
||||||
@@ -113,17 +153,29 @@ async def get_project(
|
|||||||
async def update_project(
|
async def update_project(
|
||||||
project_id: str,
|
project_id: str,
|
||||||
project_update: ProjectUpdate,
|
project_update: ProjectUpdate,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db),
|
||||||
|
request: Request = None
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
logger.info(f"更新项目: {project_id}")
|
# 从认证中间件获取用户ID
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
if not user_id:
|
||||||
|
logger.warning("未登录用户尝试更新项目")
|
||||||
|
raise HTTPException(status_code=401, detail="未登录")
|
||||||
|
|
||||||
|
logger.info(f"更新项目: project_id={project_id}, user_id={user_id}")
|
||||||
|
|
||||||
|
# 只查询当前用户的项目
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Project).where(Project.id == project_id)
|
select(Project).where(
|
||||||
|
Project.id == project_id,
|
||||||
|
Project.user_id == user_id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
project = result.scalar_one_or_none()
|
project = result.scalar_one_or_none()
|
||||||
|
|
||||||
if not project:
|
if not project:
|
||||||
logger.warning(f"项目不存在: {project_id}")
|
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
|
||||||
raise HTTPException(status_code=404, detail="项目不存在")
|
raise HTTPException(status_code=404, detail="项目不存在")
|
||||||
|
|
||||||
update_data = project_update.model_dump(exclude_unset=True)
|
update_data = project_update.model_dump(exclude_unset=True)
|
||||||
@@ -149,22 +201,30 @@ async def delete_project(
|
|||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
logger.info(f"删除项目: {project_id}")
|
# 从认证中间件获取用户ID
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
if not user_id:
|
||||||
|
logger.warning("未登录用户尝试删除项目")
|
||||||
|
raise HTTPException(status_code=401, detail="未登录")
|
||||||
|
|
||||||
|
logger.info(f"删除项目: project_id={project_id}, user_id={user_id}")
|
||||||
|
|
||||||
|
# 只查询当前用户的项目
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Project).where(Project.id == project_id)
|
select(Project).where(
|
||||||
|
Project.id == project_id,
|
||||||
|
Project.user_id == user_id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
project = result.scalar_one_or_none()
|
project = result.scalar_one_or_none()
|
||||||
|
|
||||||
if not project:
|
if not project:
|
||||||
logger.warning(f"项目不存在: {project_id}")
|
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
|
||||||
raise HTTPException(status_code=404, detail="项目不存在")
|
raise HTTPException(status_code=404, detail="项目不存在")
|
||||||
|
|
||||||
project_title = project.title
|
project_title = project.title
|
||||||
|
|
||||||
# 从认证中间件获取用户ID
|
# 删除向量数据库中的记忆(user_id已在上面获取)
|
||||||
user_id = getattr(request.state, 'user_id', None)
|
|
||||||
|
|
||||||
# 删除向量数据库中的记忆
|
|
||||||
if user_id:
|
if user_id:
|
||||||
try:
|
try:
|
||||||
await memory_service.delete_project_memories(user_id, project_id)
|
await memory_service.delete_project_memories(user_id, project_id)
|
||||||
@@ -234,22 +294,33 @@ async def delete_project(
|
|||||||
@router.get("/{project_id}/export", summary="导出项目章节为TXT")
|
@router.get("/{project_id}/export", summary="导出项目章节为TXT")
|
||||||
async def export_project_chapters(
|
async def export_project_chapters(
|
||||||
project_id: str,
|
project_id: str,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db),
|
||||||
|
request: Request = None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
导出项目的所有章节内容为TXT文本文件
|
导出项目的所有章节内容为TXT文本文件
|
||||||
按章节顺序组织,包含项目基本信息
|
按章节顺序组织,包含项目基本信息
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
logger.info(f"开始导出项目: {project_id}")
|
# 从认证中间件获取用户ID
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
if not user_id:
|
||||||
|
logger.warning("未登录用户尝试导出项目")
|
||||||
|
raise HTTPException(status_code=401, detail="未登录")
|
||||||
|
|
||||||
|
logger.info(f"开始导出项目: project_id={project_id}, user_id={user_id}")
|
||||||
|
|
||||||
|
# 只查询当前用户的项目
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Project).where(Project.id == project_id)
|
select(Project).where(
|
||||||
|
Project.id == project_id,
|
||||||
|
Project.user_id == user_id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
project = result.scalar_one_or_none()
|
project = result.scalar_one_or_none()
|
||||||
|
|
||||||
if not project:
|
if not project:
|
||||||
logger.warning(f"项目不存在: {project_id}")
|
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
|
||||||
raise HTTPException(status_code=404, detail="项目不存在")
|
raise HTTPException(status_code=404, detail="项目不存在")
|
||||||
|
|
||||||
chapters_result = await db.execute(
|
chapters_result = await db.execute(
|
||||||
@@ -326,6 +397,7 @@ async def export_project_chapters(
|
|||||||
@router.post("/{project_id}/check-consistency", summary="检查数据一致性")
|
@router.post("/{project_id}/check-consistency", summary="检查数据一致性")
|
||||||
async def check_project_consistency(
|
async def check_project_consistency(
|
||||||
project_id: str,
|
project_id: str,
|
||||||
|
request: Request,
|
||||||
auto_fix: bool = True,
|
auto_fix: bool = True,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
@@ -343,15 +415,25 @@ async def check_project_consistency(
|
|||||||
- organization_members: 验证组织成员数据完整性
|
- organization_members: 验证组织成员数据完整性
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
logger.info(f"开始数据一致性检查: {project_id}, auto_fix={auto_fix}")
|
# 从认证中间件获取用户ID
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
if not user_id:
|
||||||
|
logger.warning("未登录用户尝试检查数据一致性")
|
||||||
|
raise HTTPException(status_code=401, detail="未登录")
|
||||||
|
|
||||||
|
logger.info(f"开始数据一致性检查: project_id={project_id}, user_id={user_id}, auto_fix={auto_fix}")
|
||||||
|
|
||||||
|
# 只查询当前用户的项目
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Project).where(Project.id == project_id)
|
select(Project).where(
|
||||||
|
Project.id == project_id,
|
||||||
|
Project.user_id == user_id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
project = result.scalar_one_or_none()
|
project = result.scalar_one_or_none()
|
||||||
|
|
||||||
if not project:
|
if not project:
|
||||||
logger.warning(f"项目不存在: {project_id}")
|
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
|
||||||
raise HTTPException(status_code=404, detail="项目不存在")
|
raise HTTPException(status_code=404, detail="项目不存在")
|
||||||
|
|
||||||
report = await run_full_data_consistency_check(project_id, db, auto_fix)
|
report = await run_full_data_consistency_check(project_id, db, auto_fix)
|
||||||
@@ -369,6 +451,7 @@ async def check_project_consistency(
|
|||||||
@router.post("/{project_id}/fix-organizations", summary="修复组织记录")
|
@router.post("/{project_id}/fix-organizations", summary="修复组织记录")
|
||||||
async def fix_project_organizations(
|
async def fix_project_organizations(
|
||||||
project_id: str,
|
project_id: str,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -377,15 +460,25 @@ async def fix_project_organizations(
|
|||||||
为所有is_organization=True但没有Organization记录的Character创建记录
|
为所有is_organization=True但没有Organization记录的Character创建记录
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
logger.info(f"开始修复组织记录: {project_id}")
|
# 从认证中间件获取用户ID
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
if not user_id:
|
||||||
|
logger.warning("未登录用户尝试修复组织记录")
|
||||||
|
raise HTTPException(status_code=401, detail="未登录")
|
||||||
|
|
||||||
|
logger.info(f"开始修复组织记录: project_id={project_id}, user_id={user_id}")
|
||||||
|
|
||||||
|
# 只查询当前用户的项目
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Project).where(Project.id == project_id)
|
select(Project).where(
|
||||||
|
Project.id == project_id,
|
||||||
|
Project.user_id == user_id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
project = result.scalar_one_or_none()
|
project = result.scalar_one_or_none()
|
||||||
|
|
||||||
if not project:
|
if not project:
|
||||||
logger.warning(f"项目不存在: {project_id}")
|
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
|
||||||
raise HTTPException(status_code=404, detail="项目不存在")
|
raise HTTPException(status_code=404, detail="项目不存在")
|
||||||
|
|
||||||
fixed_count, total_count = await fix_missing_organization_records(project_id, db)
|
fixed_count, total_count = await fix_missing_organization_records(project_id, db)
|
||||||
@@ -407,6 +500,7 @@ async def fix_project_organizations(
|
|||||||
@router.post("/{project_id}/fix-member-counts", summary="修复成员计数")
|
@router.post("/{project_id}/fix-member-counts", summary="修复成员计数")
|
||||||
async def fix_project_member_counts(
|
async def fix_project_member_counts(
|
||||||
project_id: str,
|
project_id: str,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -415,15 +509,25 @@ async def fix_project_member_counts(
|
|||||||
从实际成员记录重新计算每个组织的member_count
|
从实际成员记录重新计算每个组织的member_count
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
logger.info(f"开始修复成员计数: {project_id}")
|
# 从认证中间件获取用户ID
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
if not user_id:
|
||||||
|
logger.warning("未登录用户尝试修复成员计数")
|
||||||
|
raise HTTPException(status_code=401, detail="未登录")
|
||||||
|
|
||||||
|
logger.info(f"开始修复成员计数: project_id={project_id}, user_id={user_id}")
|
||||||
|
|
||||||
|
# 只查询当前用户的项目
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Project).where(Project.id == project_id)
|
select(Project).where(
|
||||||
|
Project.id == project_id,
|
||||||
|
Project.user_id == user_id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
project = result.scalar_one_or_none()
|
project = result.scalar_one_or_none()
|
||||||
|
|
||||||
if not project:
|
if not project:
|
||||||
logger.warning(f"项目不存在: {project_id}")
|
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
|
||||||
raise HTTPException(status_code=404, detail="项目不存在")
|
raise HTTPException(status_code=404, detail="项目不存在")
|
||||||
|
|
||||||
fixed_count, total_count = await fix_organization_member_counts(project_id, db)
|
fixed_count, total_count = await fix_organization_member_counts(project_id, db)
|
||||||
@@ -445,6 +549,7 @@ async def fix_project_member_counts(
|
|||||||
@router.post("/{project_id}/export-data", summary="导出项目数据为JSON")
|
@router.post("/{project_id}/export-data", summary="导出项目数据为JSON")
|
||||||
async def export_project_data(
|
async def export_project_data(
|
||||||
project_id: str,
|
project_id: str,
|
||||||
|
request: Request,
|
||||||
options: ExportOptions,
|
options: ExportOptions,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
@@ -459,16 +564,25 @@ async def export_project_data(
|
|||||||
JSON文件下载
|
JSON文件下载
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
logger.info(f"开始导出项目数据: {project_id}")
|
# 从认证中间件获取用户ID
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
if not user_id:
|
||||||
|
logger.warning("未登录用户尝试导出项目数据")
|
||||||
|
raise HTTPException(status_code=401, detail="未登录")
|
||||||
|
|
||||||
# 检查项目是否存在
|
logger.info(f"开始导出项目数据: project_id={project_id}, user_id={user_id}")
|
||||||
|
|
||||||
|
# 只查询当前用户的项目
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Project).where(Project.id == project_id)
|
select(Project).where(
|
||||||
|
Project.id == project_id,
|
||||||
|
Project.user_id == user_id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
project = result.scalar_one_or_none()
|
project = result.scalar_one_or_none()
|
||||||
|
|
||||||
if not project:
|
if not project:
|
||||||
logger.warning(f"项目不存在: {project_id}")
|
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
|
||||||
raise HTTPException(status_code=404, detail="项目不存在")
|
raise HTTPException(status_code=404, detail="项目不存在")
|
||||||
|
|
||||||
# 导出数据
|
# 导出数据
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
"""关系管理API"""
|
"""关系管理API"""
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy import select, or_, and_
|
from sqlalchemy import select, or_, and_
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
@@ -12,6 +12,7 @@ from app.models.relationship import (
|
|||||||
OrganizationMember
|
OrganizationMember
|
||||||
)
|
)
|
||||||
from app.models.character import Character
|
from app.models.character import Character
|
||||||
|
from app.models.project import Project
|
||||||
from app.schemas.relationship import (
|
from app.schemas.relationship import (
|
||||||
RelationshipTypeResponse,
|
RelationshipTypeResponse,
|
||||||
CharacterRelationshipCreate,
|
CharacterRelationshipCreate,
|
||||||
@@ -27,6 +28,26 @@ router = APIRouter(prefix="/relationships", tags=["关系管理"])
|
|||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
|
||||||
|
"""验证用户是否有权访问指定项目"""
|
||||||
|
if not user_id:
|
||||||
|
raise HTTPException(status_code=401, detail="未登录")
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Project).where(
|
||||||
|
Project.id == project_id,
|
||||||
|
Project.user_id == user_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
project = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not project:
|
||||||
|
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
|
||||||
|
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||||
|
|
||||||
|
return project
|
||||||
|
|
||||||
|
|
||||||
@router.get("/types", response_model=List[RelationshipTypeResponse], summary="获取关系类型列表")
|
@router.get("/types", response_model=List[RelationshipTypeResponse], summary="获取关系类型列表")
|
||||||
async def get_relationship_types(db: AsyncSession = Depends(get_db)):
|
async def get_relationship_types(db: AsyncSession = Depends(get_db)):
|
||||||
"""获取所有预定义的关系类型"""
|
"""获取所有预定义的关系类型"""
|
||||||
@@ -38,9 +59,14 @@ async def get_relationship_types(db: AsyncSession = Depends(get_db)):
|
|||||||
@router.get("/project/{project_id}", response_model=List[CharacterRelationshipResponse], summary="获取项目的所有关系")
|
@router.get("/project/{project_id}", response_model=List[CharacterRelationshipResponse], summary="获取项目的所有关系")
|
||||||
async def get_project_relationships(
|
async def get_project_relationships(
|
||||||
project_id: str,
|
project_id: str,
|
||||||
|
request: Request,
|
||||||
character_id: Optional[str] = Query(None, description="筛选特定角色的关系"),
|
character_id: Optional[str] = Query(None, description="筛选特定角色的关系"),
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(project_id, user_id, db)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
获取项目中的所有角色关系
|
获取项目中的所有角色关系
|
||||||
|
|
||||||
@@ -70,8 +96,13 @@ async def get_project_relationships(
|
|||||||
@router.get("/graph/{project_id}", response_model=RelationshipGraphData, summary="获取关系图谱数据")
|
@router.get("/graph/{project_id}", response_model=RelationshipGraphData, summary="获取关系图谱数据")
|
||||||
async def get_relationship_graph(
|
async def get_relationship_graph(
|
||||||
project_id: str,
|
project_id: str,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(project_id, user_id, db)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
获取用于可视化的关系图谱数据
|
获取用于可视化的关系图谱数据
|
||||||
|
|
||||||
@@ -122,6 +153,7 @@ async def get_relationship_graph(
|
|||||||
@router.post("/", response_model=CharacterRelationshipResponse, summary="创建角色关系")
|
@router.post("/", response_model=CharacterRelationshipResponse, summary="创建角色关系")
|
||||||
async def create_relationship(
|
async def create_relationship(
|
||||||
relationship: CharacterRelationshipCreate,
|
relationship: CharacterRelationshipCreate,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -131,6 +163,10 @@ async def create_relationship(
|
|||||||
- 可以指定预定义的关系类型或自定义关系名称
|
- 可以指定预定义的关系类型或自定义关系名称
|
||||||
- 可以设置亲密度、状态等属性
|
- 可以设置亲密度、状态等属性
|
||||||
"""
|
"""
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(relationship.project_id, user_id, db)
|
||||||
|
|
||||||
# 验证角色是否存在
|
# 验证角色是否存在
|
||||||
char_from = await db.execute(
|
char_from = await db.execute(
|
||||||
select(Character).where(Character.id == relationship.character_from_id)
|
select(Character).where(Character.id == relationship.character_from_id)
|
||||||
@@ -161,6 +197,7 @@ async def create_relationship(
|
|||||||
async def update_relationship(
|
async def update_relationship(
|
||||||
relationship_id: str,
|
relationship_id: str,
|
||||||
relationship: CharacterRelationshipUpdate,
|
relationship: CharacterRelationshipUpdate,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""更新角色关系的属性(亲密度、状态等)"""
|
"""更新角色关系的属性(亲密度、状态等)"""
|
||||||
@@ -174,6 +211,10 @@ async def update_relationship(
|
|||||||
if not db_rel:
|
if not db_rel:
|
||||||
raise HTTPException(status_code=404, detail="关系不存在")
|
raise HTTPException(status_code=404, detail="关系不存在")
|
||||||
|
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(db_rel.project_id, user_id, db)
|
||||||
|
|
||||||
# 更新字段
|
# 更新字段
|
||||||
update_data = relationship.model_dump(exclude_unset=True)
|
update_data = relationship.model_dump(exclude_unset=True)
|
||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
@@ -189,6 +230,7 @@ async def update_relationship(
|
|||||||
@router.delete("/{relationship_id}", summary="删除关系")
|
@router.delete("/{relationship_id}", summary="删除关系")
|
||||||
async def delete_relationship(
|
async def delete_relationship(
|
||||||
relationship_id: str,
|
relationship_id: str,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""删除角色关系"""
|
"""删除角色关系"""
|
||||||
@@ -202,6 +244,10 @@ async def delete_relationship(
|
|||||||
if not db_rel:
|
if not db_rel:
|
||||||
raise HTTPException(status_code=404, detail="关系不存在")
|
raise HTTPException(status_code=404, detail="关系不存在")
|
||||||
|
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(db_rel.project_id, user_id, db)
|
||||||
|
|
||||||
await db.delete(db_rel)
|
await db.delete(db_rel)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
|
|||||||
@@ -183,7 +183,13 @@ async def world_building_generator(
|
|||||||
# 保存到数据库
|
# 保存到数据库
|
||||||
yield await SSEResponse.send_progress("保存到数据库...", 90)
|
yield await SSEResponse.send_progress("保存到数据库...", 90)
|
||||||
|
|
||||||
|
# 确保user_id存在
|
||||||
|
if not user_id:
|
||||||
|
yield await SSEResponse.send_error("用户ID缺失,无法创建项目", 401)
|
||||||
|
return
|
||||||
|
|
||||||
project = Project(
|
project = Project(
|
||||||
|
user_id=user_id, # 添加user_id字段
|
||||||
title=title,
|
title=title,
|
||||||
description=description,
|
description=description,
|
||||||
theme=theme,
|
theme=theme,
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
"""写作风格管理 API"""
|
"""写作风格管理 API"""
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy import select, func, delete
|
from sqlalchemy import select, func, delete
|
||||||
from typing import List
|
from typing import List
|
||||||
@@ -16,8 +16,30 @@ from ..schemas.writing_style import (
|
|||||||
SetDefaultStyleRequest
|
SetDefaultStyleRequest
|
||||||
)
|
)
|
||||||
from ..services.prompt_service import WritingStyleManager
|
from ..services.prompt_service import WritingStyleManager
|
||||||
|
from ..logger import get_logger
|
||||||
|
|
||||||
router = APIRouter(prefix="/writing-styles", tags=["writing-styles"])
|
router = APIRouter(prefix="/writing-styles", tags=["writing-styles"])
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
|
||||||
|
"""验证用户是否有权访问指定项目"""
|
||||||
|
if not user_id:
|
||||||
|
raise HTTPException(status_code=401, detail="未登录")
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Project).where(
|
||||||
|
Project.id == project_id,
|
||||||
|
Project.user_id == user_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
project = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not project:
|
||||||
|
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
|
||||||
|
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||||
|
|
||||||
|
return project
|
||||||
|
|
||||||
|
|
||||||
@router.get("/presets/list", response_model=List[dict])
|
@router.get("/presets/list", response_model=List[dict])
|
||||||
@@ -42,6 +64,7 @@ async def get_preset_styles():
|
|||||||
@router.post("", response_model=WritingStyleResponse, status_code=201)
|
@router.post("", response_model=WritingStyleResponse, status_code=201)
|
||||||
async def create_writing_style(
|
async def create_writing_style(
|
||||||
style_data: WritingStyleCreate,
|
style_data: WritingStyleCreate,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -50,13 +73,9 @@ async def create_writing_style(
|
|||||||
- **基于预设创建**:提供 preset_id,系统会自动填充预设内容
|
- **基于预设创建**:提供 preset_id,系统会自动填充预设内容
|
||||||
- **完全自定义**:不提供 preset_id,需要手动填写所有字段
|
- **完全自定义**:不提供 preset_id,需要手动填写所有字段
|
||||||
"""
|
"""
|
||||||
# 验证项目是否存在
|
# 验证用户权限
|
||||||
result = await db.execute(
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
select(Project).where(Project.id == style_data.project_id)
|
await verify_project_access(style_data.project_id, user_id, db)
|
||||||
)
|
|
||||||
project = result.scalar_one_or_none()
|
|
||||||
if not project:
|
|
||||||
raise HTTPException(status_code=404, detail="项目不存在")
|
|
||||||
|
|
||||||
# 如果基于预设创建,获取预设内容
|
# 如果基于预设创建,获取预设内容
|
||||||
if style_data.preset_id:
|
if style_data.preset_id:
|
||||||
@@ -120,6 +139,7 @@ async def create_writing_style(
|
|||||||
@router.get("/project/{project_id}", response_model=WritingStyleListResponse)
|
@router.get("/project/{project_id}", response_model=WritingStyleListResponse)
|
||||||
async def get_project_styles(
|
async def get_project_styles(
|
||||||
project_id: str,
|
project_id: str,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -128,13 +148,9 @@ async def get_project_styles(
|
|||||||
返回:全局预设风格 + 该项目的自定义风格
|
返回:全局预设风格 + 该项目的自定义风格
|
||||||
按 order_index 排序,并标记哪个是当前项目的默认风格
|
按 order_index 排序,并标记哪个是当前项目的默认风格
|
||||||
"""
|
"""
|
||||||
# 验证项目是否存在
|
# 验证用户权限
|
||||||
result = await db.execute(
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
select(Project).where(Project.id == project_id)
|
await verify_project_access(project_id, user_id, db)
|
||||||
)
|
|
||||||
project = result.scalar_one_or_none()
|
|
||||||
if not project:
|
|
||||||
raise HTTPException(status_code=404, detail="项目不存在")
|
|
||||||
|
|
||||||
# 获取该项目的默认风格ID
|
# 获取该项目的默认风格ID
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
@@ -222,6 +238,7 @@ async def get_writing_style(
|
|||||||
async def update_writing_style(
|
async def update_writing_style(
|
||||||
style_id: int,
|
style_id: int,
|
||||||
style_data: WritingStyleUpdate,
|
style_data: WritingStyleUpdate,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -241,6 +258,10 @@ async def update_writing_style(
|
|||||||
if style.project_id is None:
|
if style.project_id is None:
|
||||||
raise HTTPException(status_code=403, detail="不能修改全局预设风格,只能修改自定义风格")
|
raise HTTPException(status_code=403, detail="不能修改全局预设风格,只能修改自定义风格")
|
||||||
|
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(style.project_id, user_id, db)
|
||||||
|
|
||||||
# 更新字段
|
# 更新字段
|
||||||
update_data = style_data.model_dump(exclude_unset=True)
|
update_data = style_data.model_dump(exclude_unset=True)
|
||||||
|
|
||||||
@@ -279,6 +300,7 @@ async def update_writing_style(
|
|||||||
@router.delete("/{style_id}", status_code=204)
|
@router.delete("/{style_id}", status_code=204)
|
||||||
async def delete_writing_style(
|
async def delete_writing_style(
|
||||||
style_id: int,
|
style_id: int,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -300,6 +322,10 @@ async def delete_writing_style(
|
|||||||
if style.project_id is None:
|
if style.project_id is None:
|
||||||
raise HTTPException(status_code=403, detail="不能删除全局预设风格,只能删除自定义风格")
|
raise HTTPException(status_code=403, detail="不能删除全局预设风格,只能删除自定义风格")
|
||||||
|
|
||||||
|
# 验证用户权限
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
await verify_project_access(style.project_id, user_id, db)
|
||||||
|
|
||||||
# 检查是否有项目将其设置为默认风格
|
# 检查是否有项目将其设置为默认风格
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(ProjectDefaultStyle).where(ProjectDefaultStyle.style_id == style_id)
|
select(ProjectDefaultStyle).where(ProjectDefaultStyle.style_id == style_id)
|
||||||
@@ -321,6 +347,7 @@ async def delete_writing_style(
|
|||||||
async def set_default_style(
|
async def set_default_style(
|
||||||
style_id: int,
|
style_id: int,
|
||||||
request_data: SetDefaultStyleRequest,
|
request_data: SetDefaultStyleRequest,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -335,13 +362,9 @@ async def set_default_style(
|
|||||||
"""
|
"""
|
||||||
project_id = request_data.project_id
|
project_id = request_data.project_id
|
||||||
|
|
||||||
# 验证项目是否存在
|
# 验证用户权限
|
||||||
result = await db.execute(
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
select(Project).where(Project.id == project_id)
|
await verify_project_access(project_id, user_id, db)
|
||||||
)
|
|
||||||
project = result.scalar_one_or_none()
|
|
||||||
if not project:
|
|
||||||
raise HTTPException(status_code=404, detail="项目不存在")
|
|
||||||
|
|
||||||
# 验证风格是否存在
|
# 验证风格是否存在
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
@@ -379,6 +402,7 @@ async def set_default_style(
|
|||||||
@router.post("/project/{project_id}/init-defaults", response_model=WritingStyleListResponse)
|
@router.post("/project/{project_id}/init-defaults", response_model=WritingStyleListResponse)
|
||||||
async def initialize_default_styles(
|
async def initialize_default_styles(
|
||||||
project_id: str,
|
project_id: str,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -387,13 +411,9 @@ async def initialize_default_styles(
|
|||||||
新架构下,预设风格是全局的,不需要为每个项目单独初始化
|
新架构下,预设风格是全局的,不需要为每个项目单独初始化
|
||||||
该接口保留用于兼容性,直接返回项目可用的所有风格
|
该接口保留用于兼容性,直接返回项目可用的所有风格
|
||||||
"""
|
"""
|
||||||
# 验证项目是否存在
|
# 验证用户权限
|
||||||
result = await db.execute(
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
select(Project).where(Project.id == project_id)
|
await verify_project_access(project_id, user_id, db)
|
||||||
)
|
|
||||||
project = result.scalar_one_or_none()
|
|
||||||
if not project:
|
|
||||||
raise HTTPException(status_code=404, detail="项目不存在")
|
|
||||||
|
|
||||||
# 直接返回项目可用的所有风格(全局预设 + 项目自定义)
|
# 直接返回项目可用的所有风格(全局预设 + 项目自定义)
|
||||||
return await get_project_styles(project_id, db)
|
return await get_project_styles(project_id, request, db)
|
||||||
+31
-6
@@ -3,6 +3,7 @@ from pydantic_settings import BaseSettings
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
# 获取项目根目录(从backend/app/config.py向上两级)
|
# 获取项目根目录(从backend/app/config.py向上两级)
|
||||||
PROJECT_ROOT = Path(__file__).parent.parent
|
PROJECT_ROOT = Path(__file__).parent.parent
|
||||||
@@ -12,13 +13,15 @@ DATA_DIR.mkdir(exist_ok=True)
|
|||||||
# 配置模块使用标准logging(在logger.py初始化之前)
|
# 配置模块使用标准logging(在logger.py初始化之前)
|
||||||
config_logger = logging.getLogger(__name__)
|
config_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# 数据库文件路径(绝对路径)
|
# 数据库配置:支持PostgreSQL和SQLite
|
||||||
|
# 优先使用环境变量DATABASE_URL,否则使用SQLite
|
||||||
DB_FILE = DATA_DIR / "ai_story.db"
|
DB_FILE = DATA_DIR / "ai_story.db"
|
||||||
|
DEFAULT_SQLITE_URL = f"sqlite+aiosqlite:///{str(DB_FILE.absolute()).replace(chr(92), '/')}"
|
||||||
|
|
||||||
# 生成数据库URL(在类外部生成,确保使用绝对路径)
|
# 从环境变量获取数据库URL,如果未设置则使用SQLite
|
||||||
# 将Windows反斜杠转换为正斜杠,SQLite URL格式要求
|
DATABASE_URL = os.getenv("DATABASE_URL", DEFAULT_SQLITE_URL)
|
||||||
DATABASE_URL = f"sqlite+aiosqlite:///{str(DB_FILE.absolute()).replace(chr(92), '/')}"
|
|
||||||
config_logger.debug(f"数据库文件路径: {DB_FILE}")
|
config_logger.debug(f"数据库类型: {'PostgreSQL' if 'postgresql' in DATABASE_URL else 'SQLite'}")
|
||||||
config_logger.debug(f"数据库URL: {DATABASE_URL}")
|
config_logger.debug(f"数据库URL: {DATABASE_URL}")
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
@@ -41,9 +44,31 @@ class Settings(BaseSettings):
|
|||||||
# CORS配置
|
# CORS配置
|
||||||
cors_origins: list[str] = ["http://localhost:8000", "http://127.0.0.1:8000"]
|
cors_origins: list[str] = ["http://localhost:8000", "http://127.0.0.1:8000"]
|
||||||
|
|
||||||
# 数据库配置 - 使用预先计算好的绝对路径URL
|
# 数据库配置 - 支持PostgreSQL和SQLite
|
||||||
database_url: str = DATABASE_URL
|
database_url: str = DATABASE_URL
|
||||||
|
|
||||||
|
# PostgreSQL连接池配置(优化后支持80-150并发用户)
|
||||||
|
database_pool_size: int = 30 # 核心连接池大小(从20提升到30)
|
||||||
|
database_max_overflow: int = 20 # 最大溢出连接数(从10提升到20)
|
||||||
|
database_pool_timeout: int = 60 # 连接池超时秒数(从30提升到60)
|
||||||
|
database_pool_recycle: int = 1800 # 连接回收时间秒数(从3600降低到1800,30分钟)
|
||||||
|
database_pool_pre_ping: bool = True # 连接前ping检测,确保连接有效
|
||||||
|
database_pool_use_lifo: bool = True # 使用LIFO策略提高连接复用率
|
||||||
|
|
||||||
|
# 会话监控配置
|
||||||
|
database_session_max_active: int = 50 # 活跃会话警告阈值(从100降低到50)
|
||||||
|
database_session_leak_threshold: int = 100 # 会话泄漏严重告警阈值
|
||||||
|
|
||||||
|
# SQLite优化配置
|
||||||
|
sqlite_cache_size_mb: int = 128 # SQLite缓存大小MB(从64提升到128)
|
||||||
|
sqlite_mmap_size_mb: int = 256 # 内存映射I/O大小MB
|
||||||
|
sqlite_wal_autocheckpoint: int = 1000 # WAL自动检查点间隔
|
||||||
|
|
||||||
|
# 数据库监控配置
|
||||||
|
database_enable_slow_query_log: bool = True # 启用慢查询日志
|
||||||
|
database_slow_query_threshold: float = 1.0 # 慢查询阈值(秒)
|
||||||
|
database_enable_metrics: bool = True # 启用性能指标收集
|
||||||
|
|
||||||
# AI服务配置
|
# AI服务配置
|
||||||
openai_api_key: Optional[str] = None
|
openai_api_key: Optional[str] = None
|
||||||
openai_base_url: Optional[str] = None
|
openai_base_url: Optional[str] = None
|
||||||
|
|||||||
+219
-7
@@ -45,12 +45,59 @@ _session_stats = {
|
|||||||
async def get_engine(user_id: str):
|
async def get_engine(user_id: str):
|
||||||
"""获取或创建用户专属的数据库引擎(线程安全)
|
"""获取或创建用户专属的数据库引擎(线程安全)
|
||||||
|
|
||||||
|
支持PostgreSQL和SQLite两种数据库:
|
||||||
|
- PostgreSQL: 所有用户共享一个数据库,通过user_id字段隔离数据
|
||||||
|
- SQLite: 每个用户一个独立的数据库文件
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: 用户ID
|
user_id: 用户ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
用户专属的异步引擎
|
用户专属的异步引擎
|
||||||
"""
|
"""
|
||||||
|
# PostgreSQL模式:所有用户共享同一个引擎
|
||||||
|
if "postgresql" in settings.database_url:
|
||||||
|
cache_key = "shared_postgres"
|
||||||
|
if cache_key in _engine_cache:
|
||||||
|
return _engine_cache[cache_key]
|
||||||
|
|
||||||
|
async with _cache_lock:
|
||||||
|
if cache_key not in _engine_cache:
|
||||||
|
# 优化后的PostgreSQL连接配置
|
||||||
|
connect_args = {
|
||||||
|
"server_settings": {
|
||||||
|
"application_name": settings.app_name,
|
||||||
|
"jit": "off", # 关闭JIT以提高短查询性能
|
||||||
|
},
|
||||||
|
"command_timeout": 60, # 命令超时60秒
|
||||||
|
"statement_cache_size": 500, # 启用语句缓存,提升重复查询性能
|
||||||
|
}
|
||||||
|
|
||||||
|
engine = create_async_engine(
|
||||||
|
settings.database_url,
|
||||||
|
echo=False, # 生产环境关闭SQL日志
|
||||||
|
future=True,
|
||||||
|
pool_size=settings.database_pool_size, # 核心连接数:30
|
||||||
|
max_overflow=settings.database_max_overflow, # 溢出连接数:20
|
||||||
|
pool_timeout=settings.database_pool_timeout, # 连接超时:60秒
|
||||||
|
pool_pre_ping=settings.database_pool_pre_ping, # 连接前检测
|
||||||
|
pool_recycle=settings.database_pool_recycle, # 连接回收:1800秒
|
||||||
|
pool_use_lifo=settings.database_pool_use_lifo, # LIFO策略提高复用
|
||||||
|
connect_args=connect_args
|
||||||
|
)
|
||||||
|
_engine_cache[cache_key] = engine
|
||||||
|
logger.info(
|
||||||
|
f"✅ PostgreSQL引擎已创建(优化配置)\n"
|
||||||
|
f" ├─ 连接池: {settings.database_pool_size} 核心 + {settings.database_max_overflow} 溢出 = {settings.database_pool_size + settings.database_max_overflow} 总连接\n"
|
||||||
|
f" ├─ 超时: {settings.database_pool_timeout}秒\n"
|
||||||
|
f" ├─ 回收: {settings.database_pool_recycle}秒\n"
|
||||||
|
f" ├─ 策略: LIFO(提高复用率)\n"
|
||||||
|
f" └─ 预估并发: 80-150用户"
|
||||||
|
)
|
||||||
|
|
||||||
|
return _engine_cache[cache_key]
|
||||||
|
|
||||||
|
# SQLite模式:每个用户独立的数据库文件
|
||||||
if user_id in _engine_cache:
|
if user_id in _engine_cache:
|
||||||
return _engine_cache[user_id]
|
return _engine_cache[user_id]
|
||||||
|
|
||||||
@@ -76,18 +123,30 @@ async def get_engine(user_id: str):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# 应用优化后的SQLite配置
|
||||||
|
cache_size = -1024 * settings.sqlite_cache_size_mb # 负数表示KB单位
|
||||||
|
mmap_size = settings.sqlite_mmap_size_mb * 1024 * 1024 # 转换为字节
|
||||||
|
|
||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
await conn.execute(text("PRAGMA journal_mode=WAL"))
|
await conn.execute(text("PRAGMA journal_mode=WAL"))
|
||||||
await conn.execute(text("PRAGMA synchronous=NORMAL"))
|
await conn.execute(text("PRAGMA synchronous=NORMAL"))
|
||||||
await conn.execute(text("PRAGMA cache_size=-64000"))
|
await conn.execute(text(f"PRAGMA cache_size={cache_size}")) # 128MB缓存
|
||||||
|
await conn.execute(text(f"PRAGMA mmap_size={mmap_size}")) # 256MB内存映射
|
||||||
await conn.execute(text("PRAGMA temp_store=MEMORY"))
|
await conn.execute(text("PRAGMA temp_store=MEMORY"))
|
||||||
await conn.execute(text("PRAGMA busy_timeout=5000"))
|
await conn.execute(text("PRAGMA busy_timeout=5000"))
|
||||||
|
await conn.execute(text(f"PRAGMA wal_autocheckpoint={settings.sqlite_wal_autocheckpoint}"))
|
||||||
|
|
||||||
logger.info(f"✅ 用户 {user_id} 的数据库已优化(WAL模式 + 64MB缓存)")
|
logger.info(
|
||||||
|
f"✅ 用户 {user_id} 的SQLite数据库已优化\n"
|
||||||
|
f" ├─ WAL模式\n"
|
||||||
|
f" ├─ 缓存: {settings.sqlite_cache_size_mb}MB\n"
|
||||||
|
f" ├─ 内存映射: {settings.sqlite_mmap_size_mb}MB\n"
|
||||||
|
f" └─ 预估并发: 15-20写入用户"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"⚠️ 用户 {user_id} 数据库优化失败: {str(e)}")
|
logger.warning(f"⚠️ 用户 {user_id} SQLite数据库优化失败: {str(e)}")
|
||||||
_engine_cache[user_id] = engine
|
_engine_cache[user_id] = engine
|
||||||
logger.info(f"为用户 {user_id} 创建数据库引擎")
|
logger.info(f"为用户 {user_id} 创建SQLite数据库引擎")
|
||||||
|
|
||||||
return _engine_cache[user_id]
|
return _engine_cache[user_id]
|
||||||
|
|
||||||
@@ -157,8 +216,11 @@ async def get_db(request: Request):
|
|||||||
|
|
||||||
logger.debug(f"📊 会话关闭 [User:{user_id}][ID:{session_id}] - 活跃:{_session_stats['active']}, 总创建:{_session_stats['created']}, 总关闭:{_session_stats['closed']}, 错误:{_session_stats['errors']}")
|
logger.debug(f"📊 会话关闭 [User:{user_id}][ID:{session_id}] - 活跃:{_session_stats['active']}, 总创建:{_session_stats['created']}, 总关闭:{_session_stats['closed']}, 错误:{_session_stats['errors']}")
|
||||||
|
|
||||||
if _session_stats["active"] > 100:
|
# 使用优化后的会话监控阈值
|
||||||
logger.warning(f"🚨 活跃会话数过多: {_session_stats['active']},可能存在连接泄漏!")
|
if _session_stats["active"] > settings.database_session_leak_threshold:
|
||||||
|
logger.error(f"🚨 严重告警:活跃会话数 {_session_stats['active']} 超过泄漏阈值 {settings.database_session_leak_threshold}!")
|
||||||
|
elif _session_stats["active"] > settings.database_session_max_active:
|
||||||
|
logger.warning(f"⚠️ 警告:活跃会话数 {_session_stats['active']} 超过警告阈值 {settings.database_session_max_active},可能存在连接泄漏!")
|
||||||
elif _session_stats["active"] < 0:
|
elif _session_stats["active"] < 0:
|
||||||
logger.error(f"🚨 活跃会话数异常: {_session_stats['active']},统计可能不准确!")
|
logger.error(f"🚨 活跃会话数异常: {_session_stats['active']},统计可能不准确!")
|
||||||
|
|
||||||
@@ -324,4 +386,154 @@ async def close_db():
|
|||||||
logger.info("所有数据库连接已关闭")
|
logger.info("所有数据库连接已关闭")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"关闭数据库连接失败: {str(e)}", exc_info=True)
|
logger.error(f"关闭数据库连接失败: {str(e)}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def get_database_stats():
|
||||||
|
"""获取数据库连接和会话统计信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 包含数据库统计信息的字典
|
||||||
|
"""
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
stats = {
|
||||||
|
"session_stats": {
|
||||||
|
"created": _session_stats["created"],
|
||||||
|
"closed": _session_stats["closed"],
|
||||||
|
"active": _session_stats["active"],
|
||||||
|
"errors": _session_stats["errors"],
|
||||||
|
"generator_exits": _session_stats["generator_exits"],
|
||||||
|
"last_check": _session_stats["last_check"],
|
||||||
|
},
|
||||||
|
"engine_cache": {
|
||||||
|
"total_engines": len(_engine_cache),
|
||||||
|
"engine_keys": list(_engine_cache.keys()),
|
||||||
|
},
|
||||||
|
"config": {
|
||||||
|
"database_type": "PostgreSQL" if "postgresql" in settings.database_url else "SQLite",
|
||||||
|
"pool_size": settings.database_pool_size,
|
||||||
|
"max_overflow": settings.database_max_overflow,
|
||||||
|
"total_connections": settings.database_pool_size + settings.database_max_overflow,
|
||||||
|
"pool_timeout": settings.database_pool_timeout,
|
||||||
|
"session_max_active_threshold": settings.database_session_max_active,
|
||||||
|
"session_leak_threshold": settings.database_session_leak_threshold,
|
||||||
|
},
|
||||||
|
"health": {
|
||||||
|
"status": "healthy",
|
||||||
|
"warnings": [],
|
||||||
|
"errors": [],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# 健康检查
|
||||||
|
if _session_stats["active"] > settings.database_session_leak_threshold:
|
||||||
|
stats["health"]["status"] = "critical"
|
||||||
|
stats["health"]["errors"].append(
|
||||||
|
f"活跃会话数 {_session_stats['active']} 超过泄漏阈值 {settings.database_session_leak_threshold}"
|
||||||
|
)
|
||||||
|
elif _session_stats["active"] > settings.database_session_max_active:
|
||||||
|
stats["health"]["status"] = "warning"
|
||||||
|
stats["health"]["warnings"].append(
|
||||||
|
f"活跃会话数 {_session_stats['active']} 超过警告阈值 {settings.database_session_max_active}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if _session_stats["active"] < 0:
|
||||||
|
stats["health"]["status"] = "error"
|
||||||
|
stats["health"]["errors"].append(f"活跃会话数异常: {_session_stats['active']}")
|
||||||
|
|
||||||
|
error_rate = (_session_stats["errors"] / max(_session_stats["created"], 1)) * 100
|
||||||
|
if error_rate > 5:
|
||||||
|
stats["health"]["status"] = "warning"
|
||||||
|
stats["health"]["warnings"].append(f"会话错误率过高: {error_rate:.2f}%")
|
||||||
|
|
||||||
|
stats["health"]["error_rate"] = f"{error_rate:.2f}%"
|
||||||
|
|
||||||
|
return stats
|
||||||
|
|
||||||
|
|
||||||
|
async def check_database_health(user_id: str = None) -> dict:
|
||||||
|
"""检查数据库连接健康状态
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 可选的用户ID,如果提供则检查特定用户的数据库
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 健康检查结果
|
||||||
|
"""
|
||||||
|
result = {
|
||||||
|
"healthy": True,
|
||||||
|
"checks": {},
|
||||||
|
"timestamp": datetime.now().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 检查引擎是否存在
|
||||||
|
if user_id:
|
||||||
|
engine = await get_engine(user_id)
|
||||||
|
cache_key = user_id
|
||||||
|
else:
|
||||||
|
if "postgresql" in settings.database_url:
|
||||||
|
cache_key = "shared_postgres"
|
||||||
|
if cache_key not in _engine_cache:
|
||||||
|
result["checks"]["engine"] = {"status": "not_initialized", "healthy": True}
|
||||||
|
return result
|
||||||
|
engine = _engine_cache[cache_key]
|
||||||
|
else:
|
||||||
|
result["checks"]["engine"] = {"status": "skipped", "message": "需要提供user_id检查SQLite"}
|
||||||
|
return result
|
||||||
|
|
||||||
|
# 测试数据库连接
|
||||||
|
AsyncSessionLocal = async_sessionmaker(
|
||||||
|
engine,
|
||||||
|
class_=AsyncSession,
|
||||||
|
expire_on_commit=False
|
||||||
|
)
|
||||||
|
|
||||||
|
async with AsyncSessionLocal() as session:
|
||||||
|
# 执行简单查询测试连接
|
||||||
|
await session.execute(text("SELECT 1"))
|
||||||
|
result["checks"]["connection"] = {"status": "ok", "healthy": True}
|
||||||
|
|
||||||
|
# 检查连接池状态(仅PostgreSQL)
|
||||||
|
if hasattr(engine.pool, 'size'):
|
||||||
|
pool_status = {
|
||||||
|
"size": engine.pool.size(),
|
||||||
|
"checked_in": engine.pool.checkedin(),
|
||||||
|
"checked_out": engine.pool.checkedout(),
|
||||||
|
"overflow": engine.pool.overflow(),
|
||||||
|
"healthy": True
|
||||||
|
}
|
||||||
|
|
||||||
|
# 连接池健康检查
|
||||||
|
if engine.pool.overflow() >= settings.database_max_overflow:
|
||||||
|
pool_status["healthy"] = False
|
||||||
|
pool_status["warning"] = "连接池溢出已满"
|
||||||
|
result["healthy"] = False
|
||||||
|
|
||||||
|
result["checks"]["pool"] = pool_status
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
result["healthy"] = False
|
||||||
|
result["checks"]["error"] = {
|
||||||
|
"status": "error",
|
||||||
|
"message": str(e),
|
||||||
|
"healthy": False
|
||||||
|
}
|
||||||
|
logger.error(f"数据库健康检查失败: {str(e)}", exc_info=True)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
async def reset_session_stats():
|
||||||
|
"""重置会话统计信息(用于测试或维护)"""
|
||||||
|
global _session_stats
|
||||||
|
_session_stats = {
|
||||||
|
"created": 0,
|
||||||
|
"closed": 0,
|
||||||
|
"active": 0,
|
||||||
|
"errors": 0,
|
||||||
|
"generator_exits": 0,
|
||||||
|
"last_check": datetime.now().isoformat()
|
||||||
|
}
|
||||||
|
logger.info("✅ 会话统计信息已重置")
|
||||||
|
return _session_stats
|
||||||
@@ -28,7 +28,6 @@ logger = get_logger(__name__)
|
|||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
"""应用生命周期管理"""
|
"""应用生命周期管理"""
|
||||||
logger.info("应用启动,等待用户登录...")
|
logger.info("应用启动,等待用户登录...")
|
||||||
logger.info("💡 MCP插件采用延迟加载策略,将在用户首次使用时自动加载")
|
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|||||||
@@ -267,13 +267,26 @@ class HTTPMCPClient:
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 尝试连接并列举工具
|
# 尝试连接并列举工具(直接调用SDK,避免重复日志)
|
||||||
await self._ensure_connected()
|
await self._ensure_connected()
|
||||||
tools = await self.list_tools()
|
|
||||||
|
result = await self._session.list_tools()
|
||||||
|
|
||||||
|
# 转换为字典格式
|
||||||
|
tools = []
|
||||||
|
for tool in result.tools:
|
||||||
|
tool_dict = {
|
||||||
|
"name": tool.name,
|
||||||
|
"description": tool.description or "",
|
||||||
|
"inputSchema": tool.inputSchema
|
||||||
|
}
|
||||||
|
tools.append(tool_dict)
|
||||||
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
response_time = round((end_time - start_time) * 1000, 2)
|
response_time = round((end_time - start_time) * 1000, 2)
|
||||||
|
|
||||||
|
logger.info(f"✅ 连接测试成功,获取到 {len(tools)} 个工具")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"message": "连接测试成功",
|
"message": "连接测试成功",
|
||||||
|
|||||||
@@ -14,8 +14,8 @@ class Character(Base):
|
|||||||
|
|
||||||
# 基本信息
|
# 基本信息
|
||||||
name = Column(String(100), nullable=False, comment="角色/组织名称")
|
name = Column(String(100), nullable=False, comment="角色/组织名称")
|
||||||
age = Column(String(20), comment="年龄")
|
age = Column(String(50), comment="年龄")
|
||||||
gender = Column(String(20), comment="性别")
|
gender = Column(String(50), comment="性别")
|
||||||
is_organization = Column(Boolean, default=False, comment="是否为组织")
|
is_organization = Column(Boolean, default=False, comment="是否为组织")
|
||||||
|
|
||||||
# 角色类型:protagonist(主角)/supporting(配角)/antagonist(反派)
|
# 角色类型:protagonist(主角)/supporting(配角)/antagonist(反派)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ class StoryMemory(Base):
|
|||||||
"""故事记忆表 - 存储结构化的故事片段和元数据"""
|
"""故事记忆表 - 存储结构化的故事片段和元数据"""
|
||||||
__tablename__ = "story_memories"
|
__tablename__ = "story_memories"
|
||||||
|
|
||||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
id = Column(String(100), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||||
project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False, index=True)
|
project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||||
chapter_id = Column(String(36), ForeignKey("chapters.id", ondelete="CASCADE"), nullable=True, index=True)
|
chapter_id = Column(String(36), ForeignKey("chapters.id", ondelete="CASCADE"), nullable=True, index=True)
|
||||||
|
|
||||||
@@ -45,7 +45,7 @@ class StoryMemory(Base):
|
|||||||
|
|
||||||
# 伏笔相关字段
|
# 伏笔相关字段
|
||||||
is_foreshadow = Column(Integer, default=0, comment="伏笔状态: 0=普通记忆, 1=已埋下伏笔, 2=伏笔已回收")
|
is_foreshadow = Column(Integer, default=0, comment="伏笔状态: 0=普通记忆, 1=已埋下伏笔, 2=伏笔已回收")
|
||||||
foreshadow_resolved_at = Column(String(36), ForeignKey("chapters.id", ondelete="SET NULL"), comment="伏笔回收的章节ID")
|
foreshadow_resolved_at = Column(String(100), ForeignKey("chapters.id", ondelete="SET NULL"), comment="伏笔回收的章节ID")
|
||||||
foreshadow_strength = Column(Float, comment="伏笔强度 0.0-1.0")
|
foreshadow_strength = Column(Float, comment="伏笔强度 0.0-1.0")
|
||||||
|
|
||||||
# 向量数据库关联
|
# 向量数据库关联
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ class Project(Base):
|
|||||||
__tablename__ = "projects"
|
__tablename__ = "projects"
|
||||||
|
|
||||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||||
|
user_id = Column(String(36), nullable=False, index=True, comment="用户ID")
|
||||||
title = Column(String(200), nullable=False, comment="项目标题")
|
title = Column(String(200), nullable=False, comment="项目标题")
|
||||||
description = Column(Text, comment="项目简介")
|
description = Column(Text, comment="项目简介")
|
||||||
theme = Column(Text, comment="主题")
|
theme = Column(Text, comment="主题")
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ class Organization(Base):
|
|||||||
|
|
||||||
# 组织特色
|
# 组织特色
|
||||||
motto = Column(String(200), comment="宗旨/口号")
|
motto = Column(String(200), comment="宗旨/口号")
|
||||||
color = Column(String(20), comment="代表颜色")
|
color = Column(String(100), comment="代表颜色")
|
||||||
|
|
||||||
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
||||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||||
|
|||||||
@@ -82,7 +82,9 @@ class AIService:
|
|||||||
self.openai_http_client = None
|
self.openai_http_client = None
|
||||||
self.openai_api_key = None
|
self.openai_api_key = None
|
||||||
self.openai_base_url = None
|
self.openai_base_url = None
|
||||||
logger.warning("OpenAI API key未配置")
|
# 只有当用户明确选择OpenAI作为提供商时才警告
|
||||||
|
if self.api_provider == "openai":
|
||||||
|
logger.warning("⚠️ OpenAI API key未配置,但被设置为当前AI提供商")
|
||||||
|
|
||||||
# 初始化Anthropic客户端
|
# 初始化Anthropic客户端
|
||||||
anthropic_key = api_key if api_provider == "anthropic" else app_settings.anthropic_api_key
|
anthropic_key = api_key if api_provider == "anthropic" else app_settings.anthropic_api_key
|
||||||
@@ -118,7 +120,9 @@ class AIService:
|
|||||||
self.anthropic_client = None
|
self.anthropic_client = None
|
||||||
else:
|
else:
|
||||||
self.anthropic_client = None
|
self.anthropic_client = None
|
||||||
logger.warning("Anthropic API key未配置")
|
# 只有当用户明确选择Anthropic作为提供商时才警告
|
||||||
|
if self.api_provider == "anthropic":
|
||||||
|
logger.warning("⚠️ Anthropic API key未配置,但被设置为当前AI提供商")
|
||||||
|
|
||||||
async def generate_text(
|
async def generate_text(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -87,7 +87,7 @@ class MemoryService:
|
|||||||
'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
|
'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
|
||||||
cache_folder=model_cache_dir,
|
cache_folder=model_cache_dir,
|
||||||
device='cpu', # 明确指定使用CPU
|
device='cpu', # 明确指定使用CPU
|
||||||
trust_remote_code=False # 安全起见
|
trust_remote_code=False, # 安全起见
|
||||||
)
|
)
|
||||||
logger.info("✅ Embedding模型加载成功 (paraphrase-multilingual-MiniLM-L12-v2)")
|
logger.info("✅ Embedding模型加载成功 (paraphrase-multilingual-MiniLM-L12-v2)")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -5,7 +5,9 @@ python-multipart==0.0.20
|
|||||||
|
|
||||||
# 数据库
|
# 数据库
|
||||||
sqlalchemy==2.0.25
|
sqlalchemy==2.0.25
|
||||||
aiosqlite==0.19.0
|
aiosqlite==0.19.0 # SQLite支持(保留用于开发环境)
|
||||||
|
asyncpg==0.29.0 # PostgreSQL异步驱动(生产环境)
|
||||||
|
psycopg2-binary==2.9.9 # PostgreSQL同步驱动(备用)
|
||||||
|
|
||||||
# 数据验证
|
# 数据验证
|
||||||
pydantic==2.12.4
|
pydantic==2.12.4
|
||||||
@@ -29,8 +31,8 @@ numpy==1.26.4
|
|||||||
chromadb==1.3.2
|
chromadb==1.3.2
|
||||||
|
|
||||||
|
|
||||||
# Transformers(锁定兼容版本)
|
# Transformers(更新到最新稳定版本以修复 FutureWarning)
|
||||||
transformers==4.35.2
|
transformers==4.57.1
|
||||||
|
|
||||||
# Sentence Transformers(基于PyTorch的文本embedding库)
|
# Sentence Transformers(更新到最新稳定版本以修复 FutureWarning)
|
||||||
sentence-transformers==2.3.1
|
sentence-transformers==5.1.2
|
||||||
|
|||||||
@@ -0,0 +1,30 @@
|
|||||||
|
-- PostgreSQL 初始化脚本
|
||||||
|
-- 此脚本会在PostgreSQL容器首次启动时自动执行
|
||||||
|
|
||||||
|
-- 确保使用UTF8编码
|
||||||
|
SET client_encoding = 'UTF8';
|
||||||
|
|
||||||
|
-- 创建必要的扩展
|
||||||
|
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
|
||||||
|
CREATE EXTENSION IF NOT EXISTS "pg_trgm";
|
||||||
|
|
||||||
|
-- 设置时区
|
||||||
|
SET timezone = 'Asia/Shanghai';
|
||||||
|
|
||||||
|
-- 优化配置(这些设置会在容器启动后生效)
|
||||||
|
-- 注意:部分配置已在docker-compose.yml的command中设置
|
||||||
|
|
||||||
|
-- 创建索引优化查询性能(表会由SQLAlchemy自动创建)
|
||||||
|
-- 这里只是预留空间,实际索引会在应用启动时创建
|
||||||
|
|
||||||
|
-- 输出初始化信息
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
RAISE NOTICE '==================================================';
|
||||||
|
RAISE NOTICE 'MuMuAINovel PostgreSQL 数据库初始化完成';
|
||||||
|
RAISE NOTICE '数据库名称: mumuai_novel';
|
||||||
|
RAISE NOTICE '字符编码: UTF8';
|
||||||
|
RAISE NOTICE '时区设置: Asia/Shanghai';
|
||||||
|
RAISE NOTICE '扩展已安装: uuid-ossp, pg_trgm';
|
||||||
|
RAISE NOTICE '==================================================';
|
||||||
|
END $$;
|
||||||
@@ -0,0 +1,816 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
SQLite to PostgreSQL 数据迁移脚本
|
||||||
|
|
||||||
|
使用方法:
|
||||||
|
python backend/scripts/migrate_sqlite_to_postgres.py
|
||||||
|
|
||||||
|
前置条件:
|
||||||
|
1. PostgreSQL数据库已创建
|
||||||
|
2. .env文件中DATABASE_URL已配置为PostgreSQL
|
||||||
|
3. SQLite数据文件存在于 backend/data/ 目录
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# 添加项目根目录到Python路径
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine, text, select
|
||||||
|
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||||
|
from app.database import Base
|
||||||
|
from app.models import (
|
||||||
|
Project, Outline, Character, Chapter, GenerationHistory,
|
||||||
|
Settings, WritingStyle, ProjectDefaultStyle,
|
||||||
|
RelationshipType, CharacterRelationship, Organization, OrganizationMember,
|
||||||
|
StoryMemory, PlotAnalysis, AnalysisTask, BatchGenerationTask,
|
||||||
|
MCPPlugin
|
||||||
|
)
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
# 创建日志目录
|
||||||
|
log_dir = Path(__file__).parent.parent / "logs"
|
||||||
|
log_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# 生成日志文件名(带时间戳)
|
||||||
|
log_filename = log_dir / f"migration_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
|
||||||
|
|
||||||
|
# 设置日志 - 同时输出到控制台和文件
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||||
|
handlers=[
|
||||||
|
logging.StreamHandler(), # 控制台输出
|
||||||
|
logging.FileHandler(log_filename, encoding='utf-8') # 文件输出
|
||||||
|
]
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.info(f"📝 日志文件: {log_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
class SQLiteToPostgresMigrator:
|
||||||
|
"""SQLite到PostgreSQL的数据迁移器"""
|
||||||
|
|
||||||
|
def __init__(self, sqlite_dir: Path, target_user_id: str):
|
||||||
|
"""
|
||||||
|
初始化迁移器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sqlite_dir: SQLite数据库文件目录
|
||||||
|
target_user_id: 目标用户ID(迁移后的数据归属)
|
||||||
|
"""
|
||||||
|
self.sqlite_dir = sqlite_dir
|
||||||
|
self.target_user_id = target_user_id
|
||||||
|
self.sqlite_files = list(sqlite_dir.glob("ai_story_user_*.db"))
|
||||||
|
|
||||||
|
# PostgreSQL连接
|
||||||
|
if "postgresql" not in settings.database_url:
|
||||||
|
raise ValueError("DATABASE_URL必须配置为PostgreSQL")
|
||||||
|
|
||||||
|
self.pg_engine = create_async_engine(
|
||||||
|
settings.database_url,
|
||||||
|
echo=False,
|
||||||
|
pool_pre_ping=True
|
||||||
|
)
|
||||||
|
|
||||||
|
self.pg_session_maker = async_sessionmaker(
|
||||||
|
self.pg_engine,
|
||||||
|
class_=AsyncSession,
|
||||||
|
expire_on_commit=False
|
||||||
|
)
|
||||||
|
|
||||||
|
async def migrate_all(self):
|
||||||
|
"""迁移所有SQLite数据库"""
|
||||||
|
if not self.sqlite_files:
|
||||||
|
logger.warning(f"未找到SQLite数据库文件: {self.sqlite_dir}")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"找到 {len(self.sqlite_files)} 个SQLite数据库文件")
|
||||||
|
|
||||||
|
# 创建PostgreSQL表结构
|
||||||
|
await self._create_tables()
|
||||||
|
|
||||||
|
# 初始化关系类型数据
|
||||||
|
await self._init_relationship_types()
|
||||||
|
|
||||||
|
# 逐个迁移
|
||||||
|
for sqlite_file in self.sqlite_files:
|
||||||
|
await self._migrate_single_db(sqlite_file)
|
||||||
|
|
||||||
|
logger.info("✅ 所有数据迁移完成")
|
||||||
|
|
||||||
|
async def _create_tables(self):
|
||||||
|
"""创建PostgreSQL表结构"""
|
||||||
|
logger.info("创建PostgreSQL表结构...")
|
||||||
|
async with self.pg_engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
logger.info("✅ 表结构创建完成")
|
||||||
|
|
||||||
|
async def _init_relationship_types(self):
|
||||||
|
"""初始化关系类型数据"""
|
||||||
|
logger.info("初始化关系类型数据...")
|
||||||
|
|
||||||
|
# 预置关系类型数据
|
||||||
|
relationship_types = [
|
||||||
|
# 家族关系
|
||||||
|
{"name": "父亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👨"},
|
||||||
|
{"name": "母亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👩"},
|
||||||
|
{"name": "兄弟", "category": "family", "reverse_name": "兄弟", "intimacy_range": "high", "icon": "👬"},
|
||||||
|
{"name": "姐妹", "category": "family", "reverse_name": "姐妹", "intimacy_range": "high", "icon": "👭"},
|
||||||
|
{"name": "子女", "category": "family", "reverse_name": "父母", "intimacy_range": "high", "icon": "👶"},
|
||||||
|
{"name": "配偶", "category": "family", "reverse_name": "配偶", "intimacy_range": "high", "icon": "💑"},
|
||||||
|
{"name": "恋人", "category": "family", "reverse_name": "恋人", "intimacy_range": "high", "icon": "💕"},
|
||||||
|
|
||||||
|
# 社交关系
|
||||||
|
{"name": "师父", "category": "social", "reverse_name": "徒弟", "intimacy_range": "high", "icon": "🎓"},
|
||||||
|
{"name": "徒弟", "category": "social", "reverse_name": "师父", "intimacy_range": "high", "icon": "📚"},
|
||||||
|
{"name": "朋友", "category": "social", "reverse_name": "朋友", "intimacy_range": "medium", "icon": "🤝"},
|
||||||
|
{"name": "同学", "category": "social", "reverse_name": "同学", "intimacy_range": "medium", "icon": "🎒"},
|
||||||
|
{"name": "邻居", "category": "social", "reverse_name": "邻居", "intimacy_range": "low", "icon": "🏘️"},
|
||||||
|
{"name": "知己", "category": "social", "reverse_name": "知己", "intimacy_range": "high", "icon": "💙"},
|
||||||
|
|
||||||
|
# 职业关系
|
||||||
|
{"name": "上司", "category": "professional", "reverse_name": "下属", "intimacy_range": "low", "icon": "👔"},
|
||||||
|
{"name": "下属", "category": "professional", "reverse_name": "上司", "intimacy_range": "low", "icon": "💼"},
|
||||||
|
{"name": "同事", "category": "professional", "reverse_name": "同事", "intimacy_range": "medium", "icon": "🤵"},
|
||||||
|
{"name": "合作伙伴", "category": "professional", "reverse_name": "合作伙伴", "intimacy_range": "medium", "icon": "🤜🤛"},
|
||||||
|
|
||||||
|
# 敌对关系
|
||||||
|
{"name": "敌人", "category": "hostile", "reverse_name": "敌人", "intimacy_range": "low", "icon": "⚔️"},
|
||||||
|
{"name": "仇人", "category": "hostile", "reverse_name": "仇人", "intimacy_range": "low", "icon": "💢"},
|
||||||
|
{"name": "竞争对手", "category": "hostile", "reverse_name": "竞争对手", "intimacy_range": "low", "icon": "🎯"},
|
||||||
|
{"name": "宿敌", "category": "hostile", "reverse_name": "宿敌", "intimacy_range": "low", "icon": "⚡"},
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with self.pg_session_maker() as session:
|
||||||
|
# 检查是否已经有数据
|
||||||
|
result = await session.execute(select(RelationshipType))
|
||||||
|
existing = result.scalars().first()
|
||||||
|
|
||||||
|
if existing:
|
||||||
|
logger.info("关系类型数据已存在,跳过初始化")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 插入预置数据
|
||||||
|
logger.info("开始插入关系类型数据...")
|
||||||
|
for rt_data in relationship_types:
|
||||||
|
relationship_type = RelationshipType(**rt_data)
|
||||||
|
session.add(relationship_type)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
logger.info(f"✅ 成功插入 {len(relationship_types)} 条关系类型数据")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"初始化关系类型数据失败: {str(e)}", exc_info=True)
|
||||||
|
# 不抛出异常,继续迁移流程
|
||||||
|
logger.warning("关系类型初始化失败,将跳过有外键依赖的记录")
|
||||||
|
|
||||||
|
async def _migrate_single_db(self, sqlite_file: Path):
|
||||||
|
"""迁移单个SQLite数据库"""
|
||||||
|
# 从文件名提取user_id
|
||||||
|
filename = sqlite_file.stem # ai_story_user_xxx
|
||||||
|
if filename.startswith("ai_story_user_"):
|
||||||
|
user_id = filename.replace("ai_story_user_", "")
|
||||||
|
else:
|
||||||
|
user_id = self.target_user_id
|
||||||
|
|
||||||
|
logger.info(f"\n{'='*60}")
|
||||||
|
logger.info(f"开始迁移: {sqlite_file.name} -> user_id: {user_id}")
|
||||||
|
logger.info(f"{'='*60}")
|
||||||
|
|
||||||
|
# 创建SQLite连接
|
||||||
|
sqlite_url = f"sqlite+aiosqlite:///{sqlite_file.absolute()}"
|
||||||
|
sqlite_engine = create_async_engine(sqlite_url, echo=False)
|
||||||
|
sqlite_session_maker = async_sessionmaker(
|
||||||
|
sqlite_engine,
|
||||||
|
class_=AsyncSession,
|
||||||
|
expire_on_commit=False
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 迁移各个表
|
||||||
|
async with sqlite_session_maker() as sqlite_session:
|
||||||
|
async with self.pg_session_maker() as pg_session:
|
||||||
|
# 按照依赖顺序迁移
|
||||||
|
await self._migrate_table(
|
||||||
|
sqlite_session, pg_session, user_id, Settings, "设置"
|
||||||
|
)
|
||||||
|
await self._migrate_table(
|
||||||
|
sqlite_session, pg_session, user_id, Project, "项目"
|
||||||
|
)
|
||||||
|
await self._migrate_table(
|
||||||
|
sqlite_session, pg_session, user_id, Character, "角色"
|
||||||
|
)
|
||||||
|
await self._migrate_table(
|
||||||
|
sqlite_session, pg_session, user_id, Outline, "大纲"
|
||||||
|
)
|
||||||
|
await self._migrate_table(
|
||||||
|
sqlite_session, pg_session, user_id, Chapter, "章节"
|
||||||
|
)
|
||||||
|
await self._migrate_table(
|
||||||
|
sqlite_session, pg_session, user_id, CharacterRelationship, "角色关系"
|
||||||
|
)
|
||||||
|
await self._migrate_table(
|
||||||
|
sqlite_session, pg_session, user_id, Organization, "组织"
|
||||||
|
)
|
||||||
|
await self._migrate_table(
|
||||||
|
sqlite_session, pg_session, user_id, OrganizationMember, "组织成员"
|
||||||
|
)
|
||||||
|
await self._migrate_table(
|
||||||
|
sqlite_session, pg_session, user_id, GenerationHistory, "生成历史"
|
||||||
|
)
|
||||||
|
await self._migrate_table(
|
||||||
|
sqlite_session, pg_session, user_id, WritingStyle, "写作风格"
|
||||||
|
)
|
||||||
|
await self._migrate_table(
|
||||||
|
sqlite_session, pg_session, user_id, ProjectDefaultStyle, "项目默认风格"
|
||||||
|
)
|
||||||
|
await self._migrate_table(
|
||||||
|
sqlite_session, pg_session, user_id, StoryMemory, "记忆"
|
||||||
|
)
|
||||||
|
await self._migrate_table(
|
||||||
|
sqlite_session, pg_session, user_id, PlotAnalysis, "剧情分析"
|
||||||
|
)
|
||||||
|
await self._migrate_table(
|
||||||
|
sqlite_session, pg_session, user_id, AnalysisTask, "分析任务"
|
||||||
|
)
|
||||||
|
await self._migrate_table(
|
||||||
|
sqlite_session, pg_session, user_id, BatchGenerationTask, "批量生成任务"
|
||||||
|
)
|
||||||
|
await self._migrate_table(
|
||||||
|
sqlite_session, pg_session, user_id, MCPPlugin, "MCP插件"
|
||||||
|
)
|
||||||
|
|
||||||
|
await pg_session.commit()
|
||||||
|
|
||||||
|
logger.info(f"✅ {sqlite_file.name} 迁移完成")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 迁移失败: {e}", exc_info=True)
|
||||||
|
finally:
|
||||||
|
await sqlite_engine.dispose()
|
||||||
|
|
||||||
|
async def _migrate_table(
|
||||||
|
self,
|
||||||
|
sqlite_session: AsyncSession,
|
||||||
|
pg_session: AsyncSession,
|
||||||
|
user_id: str,
|
||||||
|
model_class,
|
||||||
|
table_name: str
|
||||||
|
):
|
||||||
|
"""迁移单个表的数据"""
|
||||||
|
try:
|
||||||
|
# 获取SQLite表中实际存在的列
|
||||||
|
sqlite_table = model_class.__table__
|
||||||
|
sqlite_conn = await sqlite_session.connection()
|
||||||
|
|
||||||
|
# 查询SQLite表结构
|
||||||
|
inspect_result = await sqlite_conn.execute(
|
||||||
|
text(f"PRAGMA table_info({sqlite_table.name})")
|
||||||
|
)
|
||||||
|
sqlite_columns = {row[1] for row in inspect_result.fetchall()} # row[1]是列名
|
||||||
|
|
||||||
|
# 构建只包含SQLite中存在的列的查询
|
||||||
|
available_columns = [
|
||||||
|
c for c in model_class.__table__.columns
|
||||||
|
if c.name in sqlite_columns
|
||||||
|
]
|
||||||
|
|
||||||
|
if not available_columns:
|
||||||
|
logger.warning(f" ⚠️ {table_name}: 表结构不匹配,跳过")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 从SQLite读取数据(只查询存在的列)
|
||||||
|
result = await sqlite_session.execute(
|
||||||
|
select(*available_columns)
|
||||||
|
)
|
||||||
|
records = result.all()
|
||||||
|
|
||||||
|
if not records:
|
||||||
|
logger.info(f" - {table_name}: 无数据")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 为每条记录创建字典并添加user_id
|
||||||
|
migrated_count = 0
|
||||||
|
skipped_count = 0
|
||||||
|
|
||||||
|
for record in records:
|
||||||
|
# 从查询结果构建字典
|
||||||
|
record_dict = {}
|
||||||
|
for i, col in enumerate(available_columns):
|
||||||
|
record_dict[col.name] = record[i]
|
||||||
|
|
||||||
|
# 添加user_id(如果PostgreSQL模型有这个字段但SQLite没有)
|
||||||
|
if hasattr(model_class, 'user_id') and 'user_id' not in record_dict:
|
||||||
|
record_dict['user_id'] = user_id
|
||||||
|
|
||||||
|
# 验证字段长度(防止超长字段导致插入失败)
|
||||||
|
if not self._validate_field_lengths(model_class, record_dict, table_name):
|
||||||
|
skipped_count += 1
|
||||||
|
record_id = record_dict.get('id', 'unknown')
|
||||||
|
logger.warning(f" ⚠️ [{table_name}] 跳过超长字段记录 ID={record_id}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 验证外键引用(针对有外键的表)
|
||||||
|
validation_result = await self._validate_foreign_keys(pg_session, model_class, record_dict)
|
||||||
|
if not validation_result:
|
||||||
|
skipped_count += 1
|
||||||
|
record_id = record_dict.get('id', 'unknown')
|
||||||
|
logger.warning(f" ⚠️ [{table_name}] 跳过无效外键记录 ID={record_id}")
|
||||||
|
# 输出记录详情以便调试
|
||||||
|
if model_class.__tablename__ == 'story_memories':
|
||||||
|
logger.warning(f" 记忆详情: project_id={record_dict.get('project_id')}, "
|
||||||
|
f"chapter_id={record_dict.get('chapter_id')}, "
|
||||||
|
f"type={record_dict.get('memory_type')}")
|
||||||
|
elif model_class.__tablename__ == 'character_relationships':
|
||||||
|
logger.warning(f" 关系详情: project_id={record_dict.get('project_id')}, "
|
||||||
|
f"from={record_dict.get('character_from_id')}, "
|
||||||
|
f"to={record_dict.get('character_to_id')}, "
|
||||||
|
f"type_id={record_dict.get('relationship_type_id')}")
|
||||||
|
elif model_class.__tablename__ == 'organizations':
|
||||||
|
logger.warning(f" 组织详情: project_id={record_dict.get('project_id')}, "
|
||||||
|
f"character_id={record_dict.get('character_id')}")
|
||||||
|
elif model_class.__tablename__ == 'organization_members':
|
||||||
|
logger.warning(f" 成员详情: org_id={record_dict.get('organization_id')}, "
|
||||||
|
f"character_id={record_dict.get('character_id')}")
|
||||||
|
elif model_class.__tablename__ == 'writing_styles':
|
||||||
|
logger.warning(f" 写作风格详情: project_id={record_dict.get('project_id')}, "
|
||||||
|
f"name={record_dict.get('name')}, "
|
||||||
|
f"style_type={record_dict.get('style_type')}")
|
||||||
|
elif model_class.__tablename__ == 'characters':
|
||||||
|
logger.warning(f" 角色详情: project_id={record_dict.get('project_id')}, "
|
||||||
|
f"name={record_dict.get('name')}, "
|
||||||
|
f"is_organization={record_dict.get('is_organization')}")
|
||||||
|
elif model_class.__tablename__ == 'outlines':
|
||||||
|
logger.warning(f" 大纲详情: project_id={record_dict.get('project_id')}, "
|
||||||
|
f"title={record_dict.get('title')}")
|
||||||
|
elif model_class.__tablename__ == 'chapters':
|
||||||
|
logger.warning(f" 章节详情: project_id={record_dict.get('project_id')}, "
|
||||||
|
f"title={record_dict.get('title')}, "
|
||||||
|
f"chapter_number={record_dict.get('chapter_number')}")
|
||||||
|
elif model_class.__tablename__ == 'generation_history':
|
||||||
|
logger.warning(f" 生成历史详情: project_id={record_dict.get('project_id')}, "
|
||||||
|
f"chapter_id={record_dict.get('chapter_id')}, "
|
||||||
|
f"model={record_dict.get('model')}")
|
||||||
|
elif model_class.__tablename__ == 'plot_analysis':
|
||||||
|
logger.warning(f" 剧情分析详情: project_id={record_dict.get('project_id')}, "
|
||||||
|
f"chapter_id={record_dict.get('chapter_id')}, "
|
||||||
|
f"plot_stage={record_dict.get('plot_stage')}")
|
||||||
|
elif model_class.__tablename__ == 'analysis_tasks':
|
||||||
|
logger.warning(f" 分析任务详情: chapter_id={record_dict.get('chapter_id')}, "
|
||||||
|
f"project_id={record_dict.get('project_id')}, "
|
||||||
|
f"status={record_dict.get('status')}")
|
||||||
|
elif model_class.__tablename__ == 'batch_generation_tasks':
|
||||||
|
logger.warning(f" 批量生成任务详情: project_id={record_dict.get('project_id')}, "
|
||||||
|
f"status={record_dict.get('status')}, "
|
||||||
|
f"completed={record_dict.get('completed_chapters')}/{record_dict.get('total_chapters')}")
|
||||||
|
elif model_class.__tablename__ == 'project_default_styles':
|
||||||
|
logger.warning(f" 项目默认风格详情: project_id={record_dict.get('project_id')}, "
|
||||||
|
f"style_id={record_dict.get('style_id')}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 检查记录是否已存在(避免主键冲突)
|
||||||
|
record_id = record_dict.get('id')
|
||||||
|
if record_id and await self._record_exists(pg_session, model_class, record_id):
|
||||||
|
skipped_count += 1
|
||||||
|
logger.debug(f" 跳过已存在的记录: {record_id}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 创建新记录
|
||||||
|
try:
|
||||||
|
new_record = model_class(**record_dict)
|
||||||
|
pg_session.add(new_record)
|
||||||
|
migrated_count += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f" ⚠️ 跳过无效记录: {str(e)[:100]}")
|
||||||
|
skipped_count += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
await pg_session.flush()
|
||||||
|
|
||||||
|
if skipped_count > 0:
|
||||||
|
logger.info(f" ✅ {table_name}: {migrated_count} 条记录(跳过 {skipped_count} 条无效记录)")
|
||||||
|
else:
|
||||||
|
logger.info(f" ✅ {table_name}: {migrated_count} 条记录")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f" ❌ {table_name} 迁移失败: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _record_exists(
|
||||||
|
self,
|
||||||
|
pg_session: AsyncSession,
|
||||||
|
model_class,
|
||||||
|
record_id: Any
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
检查记录是否已存在
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pg_session: PostgreSQL会话
|
||||||
|
model_class: 模型类
|
||||||
|
record_id: 记录ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 记录是否存在
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 获取主键列
|
||||||
|
pk_column = list(model_class.__table__.primary_key.columns)[0]
|
||||||
|
result = await pg_session.execute(
|
||||||
|
select(pk_column).where(pk_column == record_id)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none() is not None
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _validate_foreign_keys(
|
||||||
|
self,
|
||||||
|
pg_session: AsyncSession,
|
||||||
|
model_class,
|
||||||
|
record_dict: Dict[str, Any]
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
验证记录的外键是否有效
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pg_session: PostgreSQL会话
|
||||||
|
model_class: 模型类
|
||||||
|
record_dict: 记录字典
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 外键是否全部有效
|
||||||
|
"""
|
||||||
|
from app.models import Character, Project, Chapter
|
||||||
|
|
||||||
|
# 使用no_autoflush防止过早flush
|
||||||
|
with pg_session.no_autoflush:
|
||||||
|
# 针对StoryMemory表验证外键
|
||||||
|
if model_class.__tablename__ == 'story_memories':
|
||||||
|
# 验证project_id
|
||||||
|
project_id = record_dict.get('project_id')
|
||||||
|
if project_id:
|
||||||
|
result = await pg_session.execute(
|
||||||
|
select(Project.id).where(Project.id == project_id)
|
||||||
|
)
|
||||||
|
if not result.scalar_one_or_none():
|
||||||
|
logger.warning(f" ❌ [记忆] 无效的project_id: {project_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 验证chapter_id(可选)
|
||||||
|
chapter_id = record_dict.get('chapter_id')
|
||||||
|
if chapter_id:
|
||||||
|
result = await pg_session.execute(
|
||||||
|
select(Chapter.id).where(Chapter.id == chapter_id)
|
||||||
|
)
|
||||||
|
if not result.scalar_one_or_none():
|
||||||
|
logger.warning(f" ❌ [记忆] 无效的chapter_id: {chapter_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 针对CharacterRelationship表验证外键
|
||||||
|
elif model_class.__tablename__ == 'character_relationships':
|
||||||
|
# 验证project_id
|
||||||
|
project_id = record_dict.get('project_id')
|
||||||
|
if project_id:
|
||||||
|
result = await pg_session.execute(
|
||||||
|
select(Project.id).where(Project.id == project_id)
|
||||||
|
)
|
||||||
|
if not result.scalar_one_or_none():
|
||||||
|
logger.warning(f" ❌ 无效的project_id: {project_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 验证character_from_id
|
||||||
|
char_from_id = record_dict.get('character_from_id')
|
||||||
|
if char_from_id:
|
||||||
|
result = await pg_session.execute(
|
||||||
|
select(Character.id).where(Character.id == char_from_id)
|
||||||
|
)
|
||||||
|
if not result.scalar_one_or_none():
|
||||||
|
logger.warning(f" ❌ 无效的character_from_id: {char_from_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 验证character_to_id
|
||||||
|
char_to_id = record_dict.get('character_to_id')
|
||||||
|
if char_to_id:
|
||||||
|
result = await pg_session.execute(
|
||||||
|
select(Character.id).where(Character.id == char_to_id)
|
||||||
|
)
|
||||||
|
if not result.scalar_one_or_none():
|
||||||
|
logger.warning(f" ❌ 无效的character_to_id: {char_to_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 验证relationship_type_id
|
||||||
|
rel_type_id = record_dict.get('relationship_type_id')
|
||||||
|
if rel_type_id:
|
||||||
|
result = await pg_session.execute(
|
||||||
|
select(RelationshipType.id).where(RelationshipType.id == rel_type_id)
|
||||||
|
)
|
||||||
|
if not result.scalar_one_or_none():
|
||||||
|
logger.warning(f" ❌ 无效的relationship_type_id: {rel_type_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 针对Organization表验证外键
|
||||||
|
elif model_class.__tablename__ == 'organizations':
|
||||||
|
# 验证character_id
|
||||||
|
char_id = record_dict.get('character_id')
|
||||||
|
if char_id:
|
||||||
|
result = await pg_session.execute(
|
||||||
|
select(Character.id).where(Character.id == char_id)
|
||||||
|
)
|
||||||
|
if not result.scalar_one_or_none():
|
||||||
|
logger.warning(f" ❌ [组织] 无效的character_id: {char_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 针对OrganizationMember表验证外键
|
||||||
|
elif model_class.__tablename__ == 'organization_members':
|
||||||
|
from app.models import Organization
|
||||||
|
|
||||||
|
# 验证organization_id
|
||||||
|
org_id = record_dict.get('organization_id')
|
||||||
|
if org_id:
|
||||||
|
result = await pg_session.execute(
|
||||||
|
select(Organization.id).where(Organization.id == org_id)
|
||||||
|
)
|
||||||
|
if not result.scalar_one_or_none():
|
||||||
|
logger.warning(f" ❌ 无效的organization_id: {org_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 验证character_id
|
||||||
|
char_id = record_dict.get('character_id')
|
||||||
|
if char_id:
|
||||||
|
result = await pg_session.execute(
|
||||||
|
select(Character.id).where(Character.id == char_id)
|
||||||
|
)
|
||||||
|
if not result.scalar_one_or_none():
|
||||||
|
logger.warning(f" ❌ [组织成员] 无效的character_id: {char_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 针对Character表验证外键
|
||||||
|
elif model_class.__tablename__ == 'characters':
|
||||||
|
# 验证project_id
|
||||||
|
project_id = record_dict.get('project_id')
|
||||||
|
if project_id:
|
||||||
|
result = await pg_session.execute(
|
||||||
|
select(Project.id).where(Project.id == project_id)
|
||||||
|
)
|
||||||
|
if not result.scalar_one_or_none():
|
||||||
|
logger.warning(f" ❌ [角色] 无效的project_id: {project_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 针对Outline表验证外键
|
||||||
|
elif model_class.__tablename__ == 'outlines':
|
||||||
|
# 验证project_id
|
||||||
|
project_id = record_dict.get('project_id')
|
||||||
|
if project_id:
|
||||||
|
result = await pg_session.execute(
|
||||||
|
select(Project.id).where(Project.id == project_id)
|
||||||
|
)
|
||||||
|
if not result.scalar_one_or_none():
|
||||||
|
logger.warning(f" ❌ [大纲] 无效的project_id: {project_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 针对Chapter表验证外键
|
||||||
|
elif model_class.__tablename__ == 'chapters':
|
||||||
|
# 验证project_id
|
||||||
|
project_id = record_dict.get('project_id')
|
||||||
|
if project_id:
|
||||||
|
result = await pg_session.execute(
|
||||||
|
select(Project.id).where(Project.id == project_id)
|
||||||
|
)
|
||||||
|
if not result.scalar_one_or_none():
|
||||||
|
logger.warning(f" ❌ [章节] 无效的project_id: {project_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 针对WritingStyle表验证外键
|
||||||
|
elif model_class.__tablename__ == 'writing_styles':
|
||||||
|
# 验证project_id(可选)
|
||||||
|
project_id = record_dict.get('project_id')
|
||||||
|
if project_id:
|
||||||
|
result = await pg_session.execute(
|
||||||
|
select(Project.id).where(Project.id == project_id)
|
||||||
|
)
|
||||||
|
if not result.scalar_one_or_none():
|
||||||
|
logger.warning(f" ❌ [写作风格] 无效的project_id: {project_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 针对GenerationHistory表验证外键
|
||||||
|
elif model_class.__tablename__ == 'generation_history':
|
||||||
|
# 验证project_id
|
||||||
|
project_id = record_dict.get('project_id')
|
||||||
|
if project_id:
|
||||||
|
result = await pg_session.execute(
|
||||||
|
select(Project.id).where(Project.id == project_id)
|
||||||
|
)
|
||||||
|
if not result.scalar_one_or_none():
|
||||||
|
logger.warning(f" ❌ [生成历史] 无效的project_id: {project_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 验证chapter_id(可选)
|
||||||
|
chapter_id = record_dict.get('chapter_id')
|
||||||
|
if chapter_id:
|
||||||
|
result = await pg_session.execute(
|
||||||
|
select(Chapter.id).where(Chapter.id == chapter_id)
|
||||||
|
)
|
||||||
|
if not result.scalar_one_or_none():
|
||||||
|
logger.warning(f" ❌ [生成历史] 无效的chapter_id: {chapter_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 针对PlotAnalysis表验证外键
|
||||||
|
elif model_class.__tablename__ == 'plot_analysis':
|
||||||
|
# 验证project_id(必需)
|
||||||
|
project_id = record_dict.get('project_id')
|
||||||
|
if project_id:
|
||||||
|
result = await pg_session.execute(
|
||||||
|
select(Project.id).where(Project.id == project_id)
|
||||||
|
)
|
||||||
|
if not result.scalar_one_or_none():
|
||||||
|
logger.warning(f" ❌ [剧情分析] 无效的project_id: {project_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 验证chapter_id(必需)
|
||||||
|
chapter_id = record_dict.get('chapter_id')
|
||||||
|
if chapter_id:
|
||||||
|
result = await pg_session.execute(
|
||||||
|
select(Chapter.id).where(Chapter.id == chapter_id)
|
||||||
|
)
|
||||||
|
if not result.scalar_one_or_none():
|
||||||
|
logger.warning(f" ❌ [剧情分析] 无效的chapter_id: {chapter_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 针对AnalysisTask表验证外键
|
||||||
|
elif model_class.__tablename__ == 'analysis_tasks':
|
||||||
|
# 验证chapter_id(必需)
|
||||||
|
chapter_id = record_dict.get('chapter_id')
|
||||||
|
if chapter_id:
|
||||||
|
result = await pg_session.execute(
|
||||||
|
select(Chapter.id).where(Chapter.id == chapter_id)
|
||||||
|
)
|
||||||
|
if not result.scalar_one_or_none():
|
||||||
|
logger.warning(f" ❌ [分析任务] 无效的chapter_id: {chapter_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 验证project_id
|
||||||
|
project_id = record_dict.get('project_id')
|
||||||
|
if project_id:
|
||||||
|
result = await pg_session.execute(
|
||||||
|
select(Project.id).where(Project.id == project_id)
|
||||||
|
)
|
||||||
|
if not result.scalar_one_or_none():
|
||||||
|
logger.warning(f" ❌ [分析任务] 无效的project_id: {project_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 针对BatchGenerationTask表验证外键
|
||||||
|
elif model_class.__tablename__ == 'batch_generation_tasks':
|
||||||
|
# 验证project_id(必需)
|
||||||
|
project_id = record_dict.get('project_id')
|
||||||
|
if project_id:
|
||||||
|
result = await pg_session.execute(
|
||||||
|
select(Project.id).where(Project.id == project_id)
|
||||||
|
)
|
||||||
|
if not result.scalar_one_or_none():
|
||||||
|
logger.warning(f" ❌ [批量生成任务] 无效的project_id: {project_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 针对ProjectDefaultStyle表验证外键
|
||||||
|
elif model_class.__tablename__ == 'project_default_styles':
|
||||||
|
from app.models import WritingStyle
|
||||||
|
|
||||||
|
# 验证project_id(必需)
|
||||||
|
project_id = record_dict.get('project_id')
|
||||||
|
if project_id:
|
||||||
|
result = await pg_session.execute(
|
||||||
|
select(Project.id).where(Project.id == project_id)
|
||||||
|
)
|
||||||
|
if not result.scalar_one_or_none():
|
||||||
|
logger.warning(f" ❌ [项目默认风格] 无效的project_id: {project_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 验证style_id(必需)
|
||||||
|
style_id = record_dict.get('style_id')
|
||||||
|
if style_id:
|
||||||
|
result = await pg_session.execute(
|
||||||
|
select(WritingStyle.id).where(WritingStyle.id == style_id)
|
||||||
|
)
|
||||||
|
if not result.scalar_one_or_none():
|
||||||
|
logger.warning(f" ❌ [项目默认风格] 无效的style_id: {style_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _validate_field_lengths(
|
||||||
|
self,
|
||||||
|
model_class,
|
||||||
|
record_dict: Dict[str, Any],
|
||||||
|
table_name: str
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
验证记录的字段长度是否符合模型定义
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_class: 模型类
|
||||||
|
record_dict: 记录字典
|
||||||
|
table_name: 表名(用于日志)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 字段长度是否全部有效
|
||||||
|
"""
|
||||||
|
from sqlalchemy import String
|
||||||
|
|
||||||
|
# 检查所有字符串类型字段
|
||||||
|
for column in model_class.__table__.columns:
|
||||||
|
# 只检查有长度限制的String类型字段
|
||||||
|
if isinstance(column.type, String) and column.type.length:
|
||||||
|
field_name = column.name
|
||||||
|
field_value = record_dict.get(field_name)
|
||||||
|
max_length = column.type.length
|
||||||
|
|
||||||
|
# 如果字段有值且超过最大长度
|
||||||
|
if field_value and isinstance(field_value, str) and len(field_value) > max_length:
|
||||||
|
logger.warning(
|
||||||
|
f" ❌ [{table_name}] 字段 '{field_name}' 超长: "
|
||||||
|
f"{len(field_value)} > {max_length} (截断了 {len(field_value) - max_length} 字符)"
|
||||||
|
)
|
||||||
|
# 对于敏感字段如API密钥,记录部分内容
|
||||||
|
if field_name in ['api_key', 'api_base_url']:
|
||||||
|
preview = field_value[:50] + "..." + field_value[-20:] if len(field_value) > 70 else field_value
|
||||||
|
logger.warning(f" 值预览: {preview}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def cleanup(self):
|
||||||
|
"""清理资源"""
|
||||||
|
await self.pg_engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""主函数"""
|
||||||
|
banner = """
|
||||||
|
╔══════════════════════════════════════════════════════════════╗
|
||||||
|
║ SQLite to PostgreSQL 数据迁移工具 ║
|
||||||
|
║ ║
|
||||||
|
║ 此工具将SQLite数据迁移到PostgreSQL ║
|
||||||
|
║ 请确保: ║
|
||||||
|
║ 1. PostgreSQL数据库已创建 ║
|
||||||
|
║ 2. .env中DATABASE_URL已配置为PostgreSQL ║
|
||||||
|
║ 3. SQLite数据文件存在 ║
|
||||||
|
╚══════════════════════════════════════════════════════════════╝
|
||||||
|
"""
|
||||||
|
print(banner)
|
||||||
|
logger.info(banner)
|
||||||
|
|
||||||
|
# 配置
|
||||||
|
sqlite_dir = Path(__file__).parent.parent / "data"
|
||||||
|
target_user_id = "migrated_user" # 默认用户ID
|
||||||
|
|
||||||
|
config_info = f"""
|
||||||
|
配置信息:
|
||||||
|
SQLite目录: {sqlite_dir}
|
||||||
|
PostgreSQL: {settings.database_url}
|
||||||
|
目标用户ID: {target_user_id}
|
||||||
|
日志文件: {log_filename}
|
||||||
|
"""
|
||||||
|
print(config_info)
|
||||||
|
logger.info(config_info)
|
||||||
|
|
||||||
|
# 确认
|
||||||
|
response = input("是否继续迁移? (yes/no): ")
|
||||||
|
if response.lower() not in ['yes', 'y']:
|
||||||
|
print("已取消迁移")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 执行迁移
|
||||||
|
migrator = SQLiteToPostgresMigrator(sqlite_dir, target_user_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await migrator.migrate_all()
|
||||||
|
success_msg = """
|
||||||
|
🎉 数据迁移成功完成!
|
||||||
|
|
||||||
|
下一步:
|
||||||
|
1. 测试应用功能
|
||||||
|
2. 验证数据完整性
|
||||||
|
3. 备份SQLite文件后可删除
|
||||||
|
|
||||||
|
详细日志已保存到: {}
|
||||||
|
""".format(log_filename)
|
||||||
|
print(success_msg)
|
||||||
|
logger.info(success_msg)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"\n❌ 迁移失败: {e}\n详细日志已保存到: {log_filename}"
|
||||||
|
print(error_msg)
|
||||||
|
logger.error("迁移过程出错", exc_info=True)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
await migrator.cleanup()
|
||||||
|
logger.info(f"🔒 数据库连接已关闭,日志文件: {log_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
@@ -0,0 +1,408 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
PostgreSQL 数据库自动设置脚本
|
||||||
|
|
||||||
|
功能:
|
||||||
|
1. 自动连接到PostgreSQL服务器
|
||||||
|
2. 创建数据库和用户
|
||||||
|
3. 设置权限
|
||||||
|
4. 初始化表结构
|
||||||
|
|
||||||
|
使用方法:
|
||||||
|
python backend/scripts/setup_postgres.py
|
||||||
|
|
||||||
|
前置条件:
|
||||||
|
- PostgreSQL服务已安装并运行
|
||||||
|
- 知道PostgreSQL的超级用户密码(通常是postgres用户)
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
import asyncio
|
||||||
|
from pathlib import Path
|
||||||
|
from getpass import getpass
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# 添加项目根目录到Python路径
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
|
try:
|
||||||
|
import psycopg2
|
||||||
|
from psycopg2 import sql
|
||||||
|
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
|
||||||
|
except ImportError:
|
||||||
|
print("❌ 缺少psycopg2依赖,正在安装...")
|
||||||
|
import subprocess
|
||||||
|
subprocess.check_call([sys.executable, "-m", "pip", "install", "psycopg2-binary"])
|
||||||
|
import psycopg2
|
||||||
|
from psycopg2 import sql
|
||||||
|
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
|
||||||
|
|
||||||
|
from app.database import init_db
|
||||||
|
|
||||||
|
# 设置日志
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(message)s'
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PostgreSQLSetup:
|
||||||
|
"""PostgreSQL数据库自动设置"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
host: str = "localhost",
|
||||||
|
port: int = 5432,
|
||||||
|
admin_user: str = "postgres",
|
||||||
|
admin_password: str = None,
|
||||||
|
db_name: str = "mumuai_novel",
|
||||||
|
db_user: str = "mumuai",
|
||||||
|
db_password: str = "123456"
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化设置参数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
host: PostgreSQL主机地址
|
||||||
|
port: PostgreSQL端口
|
||||||
|
admin_user: 管理员用户名
|
||||||
|
admin_password: 管理员密码
|
||||||
|
db_name: 要创建的数据库名
|
||||||
|
db_user: 要创建的用户名
|
||||||
|
db_password: 用户密码
|
||||||
|
"""
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.admin_user = admin_user
|
||||||
|
self.admin_password = admin_password
|
||||||
|
self.db_name = db_name
|
||||||
|
self.db_user = db_user
|
||||||
|
self.db_password = db_password
|
||||||
|
self.conn = None
|
||||||
|
|
||||||
|
def connect_as_admin(self) -> bool:
|
||||||
|
"""连接到PostgreSQL(使用管理员权限)"""
|
||||||
|
try:
|
||||||
|
logger.info(f"🔌 连接到 PostgreSQL ({self.host}:{self.port})...")
|
||||||
|
|
||||||
|
self.conn = psycopg2.connect(
|
||||||
|
host=self.host,
|
||||||
|
port=self.port,
|
||||||
|
user=self.admin_user,
|
||||||
|
password=self.admin_password,
|
||||||
|
database="postgres" # 连接到默认数据库
|
||||||
|
)
|
||||||
|
self.conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
|
||||||
|
|
||||||
|
logger.info(f"✅ 已连接到 PostgreSQL")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except psycopg2.OperationalError as e:
|
||||||
|
logger.error(f"❌ 连接失败: {e}")
|
||||||
|
logger.error("\n可能的原因:")
|
||||||
|
logger.error("1. PostgreSQL服务未启动")
|
||||||
|
logger.error("2. 管理员密码错误")
|
||||||
|
logger.error("3. 主机地址或端口错误")
|
||||||
|
logger.error("4. pg_hba.conf配置不允许连接")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def database_exists(self) -> bool:
|
||||||
|
"""检查数据库是否存在"""
|
||||||
|
cursor = self.conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"SELECT 1 FROM pg_database WHERE datname = %s",
|
||||||
|
(self.db_name,)
|
||||||
|
)
|
||||||
|
exists = cursor.fetchone() is not None
|
||||||
|
cursor.close()
|
||||||
|
return exists
|
||||||
|
|
||||||
|
def user_exists(self) -> bool:
|
||||||
|
"""检查用户是否存在"""
|
||||||
|
cursor = self.conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"SELECT 1 FROM pg_user WHERE usename = %s",
|
||||||
|
(self.db_user,)
|
||||||
|
)
|
||||||
|
exists = cursor.fetchone() is not None
|
||||||
|
cursor.close()
|
||||||
|
return exists
|
||||||
|
|
||||||
|
def create_user(self) -> bool:
|
||||||
|
"""创建数据库用户"""
|
||||||
|
try:
|
||||||
|
if self.user_exists():
|
||||||
|
logger.info(f"ℹ️ 用户 '{self.db_user}' 已存在")
|
||||||
|
|
||||||
|
# 询问是否重置密码
|
||||||
|
response = input(f"是否重置用户 '{self.db_user}' 的密码? (yes/no): ")
|
||||||
|
if response.lower() in ['yes', 'y']:
|
||||||
|
cursor = self.conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
sql.SQL("ALTER USER {} WITH PASSWORD %s").format(
|
||||||
|
sql.Identifier(self.db_user)
|
||||||
|
),
|
||||||
|
(self.db_password,)
|
||||||
|
)
|
||||||
|
cursor.close()
|
||||||
|
logger.info(f"✅ 用户密码已更新")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
logger.info(f"👤 创建用户 '{self.db_user}'...")
|
||||||
|
cursor = self.conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
sql.SQL("CREATE USER {} WITH PASSWORD %s").format(
|
||||||
|
sql.Identifier(self.db_user)
|
||||||
|
),
|
||||||
|
(self.db_password,)
|
||||||
|
)
|
||||||
|
cursor.close()
|
||||||
|
logger.info(f"✅ 用户创建成功")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 创建用户失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def create_database(self) -> bool:
|
||||||
|
"""创建数据库"""
|
||||||
|
try:
|
||||||
|
if self.database_exists():
|
||||||
|
logger.info(f"ℹ️ 数据库 '{self.db_name}' 已存在")
|
||||||
|
|
||||||
|
# 询问是否删除重建
|
||||||
|
response = input(f"是否删除并重建数据库 '{self.db_name}'? (yes/no): ")
|
||||||
|
if response.lower() in ['yes', 'y']:
|
||||||
|
logger.warning(f"⚠️ 删除数据库 '{self.db_name}'...")
|
||||||
|
cursor = self.conn.cursor()
|
||||||
|
# 断开所有连接
|
||||||
|
cursor.execute(
|
||||||
|
sql.SQL("""
|
||||||
|
SELECT pg_terminate_backend(pg_stat_activity.pid)
|
||||||
|
FROM pg_stat_activity
|
||||||
|
WHERE pg_stat_activity.datname = %s
|
||||||
|
AND pid <> pg_backend_pid()
|
||||||
|
"""),
|
||||||
|
(self.db_name,)
|
||||||
|
)
|
||||||
|
cursor.execute(
|
||||||
|
sql.SQL("DROP DATABASE {}").format(
|
||||||
|
sql.Identifier(self.db_name)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cursor.close()
|
||||||
|
logger.info(f"✅ 数据库已删除")
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
|
logger.info(f"🗄️ 创建数据库 '{self.db_name}'...")
|
||||||
|
cursor = self.conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
sql.SQL("CREATE DATABASE {} OWNER {}").format(
|
||||||
|
sql.Identifier(self.db_name),
|
||||||
|
sql.Identifier(self.db_user)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cursor.close()
|
||||||
|
logger.info(f"✅ 数据库创建成功")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 创建数据库失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def grant_privileges(self) -> bool:
|
||||||
|
"""授予用户权限"""
|
||||||
|
try:
|
||||||
|
logger.info(f"🔐 授予用户权限...")
|
||||||
|
cursor = self.conn.cursor()
|
||||||
|
|
||||||
|
# 授予数据库所有权限
|
||||||
|
cursor.execute(
|
||||||
|
sql.SQL("GRANT ALL PRIVILEGES ON DATABASE {} TO {}").format(
|
||||||
|
sql.Identifier(self.db_name),
|
||||||
|
sql.Identifier(self.db_user)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
cursor.close()
|
||||||
|
logger.info(f"✅ 权限授予成功")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 授予权限失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def update_env_file(self) -> bool:
|
||||||
|
"""更新.env文件"""
|
||||||
|
try:
|
||||||
|
env_file = Path(__file__).parent.parent / ".env"
|
||||||
|
|
||||||
|
database_url = (
|
||||||
|
f"postgresql+asyncpg://{self.db_user}:{self.db_password}"
|
||||||
|
f"@{self.host}:{self.port}/{self.db_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if env_file.exists():
|
||||||
|
# 读取现有内容
|
||||||
|
with open(env_file, 'r', encoding='utf-8') as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
|
||||||
|
# 更新DATABASE_URL
|
||||||
|
updated = False
|
||||||
|
for i, line in enumerate(lines):
|
||||||
|
if line.startswith('DATABASE_URL='):
|
||||||
|
lines[i] = f"DATABASE_URL={database_url}\n"
|
||||||
|
updated = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not updated:
|
||||||
|
lines.append(f"\nDATABASE_URL={database_url}\n")
|
||||||
|
|
||||||
|
# 写回文件
|
||||||
|
with open(env_file, 'w', encoding='utf-8') as f:
|
||||||
|
f.writelines(lines)
|
||||||
|
else:
|
||||||
|
# 创建新文件
|
||||||
|
with open(env_file, 'w', encoding='utf-8') as f:
|
||||||
|
f.write(f"DATABASE_URL={database_url}\n")
|
||||||
|
|
||||||
|
logger.info(f"✅ .env 文件已更新")
|
||||||
|
logger.info(f" DATABASE_URL={database_url}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 更新.env文件失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def initialize_tables(self) -> bool:
|
||||||
|
"""初始化数据库表结构"""
|
||||||
|
try:
|
||||||
|
logger.info(f"📋 初始化数据库表结构...")
|
||||||
|
await init_db('system')
|
||||||
|
logger.info(f"✅ 表结构初始化成功")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 初始化表结构失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""关闭数据库连接"""
|
||||||
|
if self.conn:
|
||||||
|
self.conn.close()
|
||||||
|
logger.info(f"🔌 已断开连接")
|
||||||
|
|
||||||
|
async def setup(self) -> bool:
|
||||||
|
"""执行完整设置流程"""
|
||||||
|
try:
|
||||||
|
# 1. 连接
|
||||||
|
if not self.connect_as_admin():
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 2. 创建用户
|
||||||
|
if not self.create_user():
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 3. 创建数据库
|
||||||
|
if not self.create_database():
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 4. 授予权限
|
||||||
|
if not self.grant_privileges():
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 5. 更新配置
|
||||||
|
if not self.update_env_file():
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 6. 关闭管理员连接
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
# 7. 初始化表结构
|
||||||
|
if not await self.initialize_tables():
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 设置过程出错: {e}")
|
||||||
|
return False
|
||||||
|
finally:
|
||||||
|
if self.conn:
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""主函数"""
|
||||||
|
print("""
|
||||||
|
╔═══════════════════════════════════════════════════════════════╗
|
||||||
|
║ PostgreSQL 数据库自动设置工具 ║
|
||||||
|
║ ║
|
||||||
|
║ 此工具将自动完成: ║
|
||||||
|
║ 1. 连接到PostgreSQL服务器 ║
|
||||||
|
║ 2. 创建数据库和用户 ║
|
||||||
|
║ 3. 设置权限 ║
|
||||||
|
║ 4. 初始化表结构 ║
|
||||||
|
║ 5. 更新.env配置文件 ║
|
||||||
|
╚═══════════════════════════════════════════════════════════════╝
|
||||||
|
""")
|
||||||
|
|
||||||
|
# 获取配置
|
||||||
|
print("请输入PostgreSQL配置信息:\n")
|
||||||
|
|
||||||
|
host = input("主机地址 [localhost]: ").strip() or "localhost"
|
||||||
|
port = input("端口 [5432]: ").strip() or "5432"
|
||||||
|
port = int(port)
|
||||||
|
|
||||||
|
admin_user = input("管理员用户名 [postgres]: ").strip() or "postgres"
|
||||||
|
admin_password = getpass(f"管理员密码: ")
|
||||||
|
|
||||||
|
print("\n请输入要创建的数据库信息:\n")
|
||||||
|
db_name = input("数据库名 [mumuai_novel]: ").strip() or "mumuai_novel"
|
||||||
|
db_user = input("数据库用户名 [mumuai]: ").strip() or "mumuai"
|
||||||
|
db_password = getpass("数据库用户密码 [mumuai123]: ") or "mumuai123"
|
||||||
|
|
||||||
|
print(f"\n{'='*60}")
|
||||||
|
print(f"配置摘要:")
|
||||||
|
print(f" 服务器: {host}:{port}")
|
||||||
|
print(f" 数据库: {db_name}")
|
||||||
|
print(f" 用户: {db_user}")
|
||||||
|
print(f"{'='*60}\n")
|
||||||
|
|
||||||
|
response = input("确认开始设置? (yes/no): ")
|
||||||
|
if response.lower() not in ['yes', 'y']:
|
||||||
|
print("已取消设置")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 执行设置
|
||||||
|
setup = PostgreSQLSetup(
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
admin_user=admin_user,
|
||||||
|
admin_password=admin_password,
|
||||||
|
db_name=db_name,
|
||||||
|
db_user=db_user,
|
||||||
|
db_password=db_password
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\n{'='*60}")
|
||||||
|
success = await setup.setup()
|
||||||
|
print(f"{'='*60}\n")
|
||||||
|
|
||||||
|
if success:
|
||||||
|
print("🎉 PostgreSQL设置完成!\n")
|
||||||
|
print("下一步:")
|
||||||
|
print("1. 启动应用: python -m app.main")
|
||||||
|
print("2. 访问: http://localhost:8000")
|
||||||
|
print("3. 查看API文档: http://localhost:8000/docs")
|
||||||
|
else:
|
||||||
|
print("❌ 设置过程中出现错误,请检查日志")
|
||||||
|
print("\n故障排查:")
|
||||||
|
print("1. 确认PostgreSQL服务正在运行")
|
||||||
|
print("2. 检查管理员用户名和密码")
|
||||||
|
print("3. 查看PostgreSQL日志")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
+74
-5
@@ -1,15 +1,67 @@
|
|||||||
services:
|
services:
|
||||||
|
postgres:
|
||||||
|
image: postgres:18-alpine
|
||||||
|
container_name: mumuainovel-postgres
|
||||||
|
environment:
|
||||||
|
POSTGRES_DB: mumuai_novel
|
||||||
|
POSTGRES_USER: mumuai
|
||||||
|
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-mumuai_password_2024}
|
||||||
|
POSTGRES_INITDB_ARGS: "--encoding=UTF8 --locale=C"
|
||||||
|
TZ: Asia/Shanghai
|
||||||
|
volumes:
|
||||||
|
- postgres_data:/var/lib/postgresql/data
|
||||||
|
- ./backend/scripts/init_postgres.sql:/docker-entrypoint-initdb.d/init.sql:ro
|
||||||
|
ports:
|
||||||
|
- "5432:5432"
|
||||||
|
restart: unless-stopped
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD-SHELL", "pg_isready -U mumuai -d mumuai_novel"]
|
||||||
|
interval: 10s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
||||||
|
start_period: 10s
|
||||||
|
networks:
|
||||||
|
- ai-story-network
|
||||||
|
command:
|
||||||
|
- postgres
|
||||||
|
- -c
|
||||||
|
- max_connections=200
|
||||||
|
- -c
|
||||||
|
- shared_buffers=256MB
|
||||||
|
- -c
|
||||||
|
- effective_cache_size=1GB
|
||||||
|
- -c
|
||||||
|
- maintenance_work_mem=64MB
|
||||||
|
- -c
|
||||||
|
- checkpoint_completion_target=0.9
|
||||||
|
- -c
|
||||||
|
- wal_buffers=16MB
|
||||||
|
- -c
|
||||||
|
- default_statistics_target=100
|
||||||
|
- -c
|
||||||
|
- random_page_cost=1.1
|
||||||
|
- -c
|
||||||
|
- effective_io_concurrency=200
|
||||||
|
- -c
|
||||||
|
- work_mem=4MB
|
||||||
|
- -c
|
||||||
|
- min_wal_size=1GB
|
||||||
|
- -c
|
||||||
|
- max_wal_size=4GB
|
||||||
|
|
||||||
mumuainovel:
|
mumuainovel:
|
||||||
build:
|
build:
|
||||||
context: .
|
context: .
|
||||||
dockerfile: Dockerfile
|
dockerfile: Dockerfile
|
||||||
image: mumujie/mumuainovel:latest
|
image: mumujie/mumuainovel:latest
|
||||||
container_name: mumuainovel
|
container_name: mumuainovel
|
||||||
|
depends_on:
|
||||||
|
postgres:
|
||||||
|
condition: service_healthy
|
||||||
ports:
|
ports:
|
||||||
- "8000:8000"
|
- "8000:8000"
|
||||||
volumes:
|
volumes:
|
||||||
# 持久化数据库和日志
|
# 持久化日志
|
||||||
- ./data:/app/data
|
|
||||||
- ./logs:/app/logs
|
- ./logs:/app/logs
|
||||||
# 挂载环境变量文件(可选)
|
# 挂载环境变量文件(可选)
|
||||||
- ./.env:/app/.env:ro
|
- ./.env:/app/.env:ro
|
||||||
@@ -21,8 +73,20 @@ services:
|
|||||||
- APP_PORT=8000
|
- APP_PORT=8000
|
||||||
- DEBUG=false
|
- DEBUG=false
|
||||||
|
|
||||||
# 重要:环境变量会从 .env 文件自动加载
|
# 数据库配置(使用PostgreSQL)
|
||||||
# 也可以在这里显式设置,优先级:此处设置 > .env 文件
|
- DATABASE_URL=postgresql+asyncpg://mumuai:${POSTGRES_PASSWORD:-mumuai_password_2024}@postgres:5432/mumuai_novel
|
||||||
|
|
||||||
|
# PostgreSQL连接池配置
|
||||||
|
- DATABASE_POOL_SIZE=30
|
||||||
|
- DATABASE_MAX_OVERFLOW=20
|
||||||
|
- DATABASE_POOL_TIMEOUT=60
|
||||||
|
- DATABASE_POOL_RECYCLE=1800
|
||||||
|
- DATABASE_POOL_PRE_PING=True
|
||||||
|
- DATABASE_POOL_USE_LIFO=True
|
||||||
|
|
||||||
|
- HTTP_PROXY=http://172.16.66.175:7890
|
||||||
|
- HTTPS_PROXY=http://172.16.66.175:7890
|
||||||
|
- NO_PROXY=localhost,127.0.0.1
|
||||||
|
|
||||||
# AI服务配置(建议在 .env 文件中设置)
|
# AI服务配置(建议在 .env 文件中设置)
|
||||||
# - OPENAI_API_KEY=${OPENAI_API_KEY}
|
# - OPENAI_API_KEY=${OPENAI_API_KEY}
|
||||||
@@ -41,10 +105,15 @@ services:
|
|||||||
interval: 30s
|
interval: 30s
|
||||||
timeout: 10s
|
timeout: 10s
|
||||||
retries: 3
|
retries: 3
|
||||||
start_period: 10s
|
start_period: 30s
|
||||||
networks:
|
networks:
|
||||||
- ai-story-network
|
- ai-story-network
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
postgres_data:
|
||||||
|
driver: local
|
||||||
|
|
||||||
networks:
|
networks:
|
||||||
ai-story-network:
|
ai-story-network:
|
||||||
driver: bridge
|
driver: bridge
|
||||||
|
|
||||||
|
|||||||
@@ -67,6 +67,14 @@ export default function ChapterAnalysis({ chapterId, visible, onClose }: Chapter
|
|||||||
}
|
}
|
||||||
|
|
||||||
const taskData: AnalysisTask = await response.json();
|
const taskData: AnalysisTask = await response.json();
|
||||||
|
|
||||||
|
// 如果状态为 none(无任务),设置 task 为 null,让前端显示"开始分析"按钮
|
||||||
|
if (taskData.status === 'none' || !taskData.has_task) {
|
||||||
|
setTask(null);
|
||||||
|
setError(null); // 清除错误,这不是错误状态
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
setTask(taskData);
|
setTask(taskData);
|
||||||
|
|
||||||
if (taskData.status === 'completed') {
|
if (taskData.status === 'completed') {
|
||||||
|
|||||||
@@ -321,6 +321,7 @@ export default function Chapters() {
|
|||||||
setAnalysisTasksMap(prev => ({
|
setAnalysisTasksMap(prev => ({
|
||||||
...prev,
|
...prev,
|
||||||
[editingId]: {
|
[editingId]: {
|
||||||
|
has_task: true,
|
||||||
task_id: taskId,
|
task_id: taskId,
|
||||||
chapter_id: editingId,
|
chapter_id: editingId,
|
||||||
status: 'pending',
|
status: 'pending',
|
||||||
|
|||||||
@@ -390,14 +390,16 @@ export interface ApiError {
|
|||||||
|
|
||||||
// 章节分析任务相关类型
|
// 章节分析任务相关类型
|
||||||
export interface AnalysisTask {
|
export interface AnalysisTask {
|
||||||
task_id: string;
|
has_task: boolean;
|
||||||
|
task_id: string | null;
|
||||||
chapter_id: string;
|
chapter_id: string;
|
||||||
status: 'pending' | 'running' | 'completed' | 'failed';
|
status: 'pending' | 'running' | 'completed' | 'failed' | 'none';
|
||||||
progress: number;
|
progress: number;
|
||||||
error_message?: string;
|
error_message?: string | null;
|
||||||
created_at?: string;
|
auto_recovered?: boolean;
|
||||||
started_at?: string;
|
created_at?: string | null;
|
||||||
completed_at?: string;
|
started_at?: string | null;
|
||||||
|
completed_at?: string | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 分析结果 - 钩子
|
// 分析结果 - 钩子
|
||||||
|
|||||||
Reference in New Issue
Block a user