This commit is contained in:
xiamuceer
2026-01-09 19:28:58 +08:00
188 changed files with 48882 additions and 8566 deletions
+1
View File
@@ -39,6 +39,7 @@ Thumbs.db
# 数据库文件(不包含在镜像中)
data/*.db
backend/data/*.db
postgres_data/
# ChromaDB数据库(不包含在镜像中,会在运行时生成)
backend/data/chroma_db/
+1 -1
View File
@@ -1 +1 @@
*.safetensors filter=lfs diff=lfs merge=lfs -text
# LFS tracking removed - models downloaded from HuggingFace at runtime
+16 -12
View File
@@ -32,9 +32,11 @@ WORKDIR /app
RUN sed -i 's/deb.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list.d/debian.sources \
&& sed -i 's/security.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list.d/debian.sources
# 安装系统依赖
# 安装系统依赖(添加数据库工具)
RUN apt-get update && apt-get install -y \
gcc \
postgresql-client \
netcat-traditional \
&& rm -rf /var/lib/apt/lists/*
# 复制后端依赖文件
@@ -46,21 +48,23 @@ RUN pip install --no-cache-dir torch==2.7.0 --index-url https://download.pytorch
# 再安装其他Python依赖(使用阿里云镜像加速)
RUN pip install --no-cache-dir -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/
# 复制后端代码
# 复制后端代码(包含embedding模型)
COPY backend/ ./
# 从前端构建阶段复制构建好的静态文件
COPY --from=frontend-builder /frontend/dist ./static
# 复制 Alembic 迁移配置和脚本(PostgreSQL
COPY backend/alembic-postgres.ini ./alembic.ini
COPY backend/alembic/postgres ./alembic
COPY backend/scripts/entrypoint.sh /app/entrypoint.sh
COPY backend/scripts/migrate.py ./scripts/migrate.py
# 赋予执行权限
RUN chmod +x /app/entrypoint.sh
# 创建必要的目录
RUN mkdir -p /app/data /app/logs /app/embedding
# 复制预下载的Embedding模型到独立目录(避免被docker-compose的data挂载覆盖)
# 这样可以避免首次运行时联网下载约420MB的模型文件
COPY backend/embedding /app/embedding
# 复制环境变量示例文件
COPY backend/.env.example ./.env.example
RUN mkdir -p /app/data /app/logs
# 暴露端口
EXPOSE 8000
@@ -80,5 +84,5 @@ ENV SENTENCE_TRANSFORMERS_HOME=/app/embedding
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
# 启动命令
CMD ["python", "-m", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
# 使用 entrypoint 脚本启动(自动执行迁移)
ENTRYPOINT ["/app/entrypoint.sh"]
+444 -417
View File
@@ -2,16 +2,47 @@
<div align="center">
![Version](https://img.shields.io/badge/version-1.0.0-blue.svg)
![Version](https://img.shields.io/badge/version-1.2.5-blue.svg)
![Python](https://img.shields.io/badge/python-3.11-blue.svg)
![FastAPI](https://img.shields.io/badge/FastAPI-0.109.0-green.svg)
![React](https://img.shields.io/badge/react-18.3.1-blue.svg)
![TypeScript](https://img.shields.io/badge/typescript-5.9.3-blue.svg)
![License](https://img.shields.io/badge/license-GPL%20v3-blue.svg)
**一款基于 AI 的智能小说创作助手,帮助你轻松创作精彩故事**
**基于 AI 的智能小说创作助手**
[特性](#-特性) • [快速开始](#-快速开始) • [部署方式](#-部署方式) • [配置说明](#%EF%B8%8F-配置说明) • [项目结构](#-项目结构)
[特性](#-特性) • [快速开始](#-快速开始) • [配置说明](#%EF%B8%8F-配置说明) • [项目结构](#-项目结构)
</div>
---
<div align="center">
## 💖 支持项目
如果这个项目对你有帮助,欢迎通过以下方式支持开发:
**[☕ 请我喝杯咖啡](https://mumuverse.space:1588/)**
### 🎁 赞助专属权益
| 权益 | 说明 |
|------|------|
| 📋 **优先需求响应** | 您的功能需求和问题反馈将获得优先处理 |
| 🚀 **Windows一键启动** | 获取免安装EXE程序,双击即可使用 |
| 💬 **专属技术支持** | 加入赞助者内部群,获得远程协助和配置指导 |
### ☕ 赞助金额
| 金额 | 描述 |
|------|------|
| ¥5 | 🌶️ 一包辣条 |
| ¥10 | 🍱 一顿拼好饭 |
| ¥20 | 🧋 一杯咖啡 |
| ¥50 | 🍖 一次烧烤 |
| ¥99 | 🍲 一顿海底捞 |
您的支持是我持续开发的动力!🙏
</div>
@@ -19,116 +50,387 @@
## ✨ 特性
- 🤖 **多 AI 模型支持** - 支持 OpenAI、Google Gemini、Anthropic Claude 等主流 AI 模型
- 📝 **智能向导** - 通过向导式引导快速创建小说项目,AI 自动生成大纲、角色和世界观
- 👥 **角色管理** - 创建和管理小说角色,包括人物关系、组织架构
- 📖 **章节编辑** - 支持章节的创建、编辑、重新生成和润色功能
- 🌐 **世界观设定** - 构建完整的故事世界观和背景设定
- 🔐 **多种登录方式** - 支持 LinuxDO OAuth 登录和本地账户登录
- 🐳 **Docker 部署** - 一键部署,开箱即用
- 💾 **数据持久化** - 基于 SQLite 的本地数据存储,支持多用户隔离
- 🎨 **现代化 UI** - 基于 Ant Design 的美观界面,响应式设计
- 🤖 **多 AI 模型** - 支持 OpenAI、Gemini、Claude 等主流模型
- 📝 **智能向导** - AI 自动生成大纲、角色和世界观
- 👥 **角色管理** - 人物关系、组织架构可视化管理
- 📖 **章节编辑** - 支持创建、编辑、重新生成和润色
- 🌐 **世界观设定** - 构建完整的故事背景
- 🔐 **多种登录** - LinuxDO OAuth 本地账户登录
- 💾 **PostgreSQL** - 生产级数据库,多用户数据隔离
- 🐳 **Docker 部署** - 一键启动,开箱即用
## 📸 项目预览
<details>
<summary>多图预警</summary>
<div align="center">
### 登录界面
![登录界面](images/1.png)
### 主界面
![主界面](images/2.png)
### 项目管理
![项目管理](images/3.png)
### 赞助我 💖
![赞助我](images/4.png)
</div>
</details>
## 📋 TODO List
以下是正在规划和开发中的功能:
### ✅ 已完成功能
- [ ] **灵感模式** - 提供创作灵感和点子生成功能
- [✔] **自定义写作风格** - 支持自定义AI写作风格和语言风格
- [ ] **支持数据导入导出** - 支持项目数据的导入导出功能
- [ ] **添加prompt调整界面** - 提供可视化的prompt模板编辑和调整界面
- [✔] **开放章节内容字数限制** - 支持用户在生成章节内容时设置字数 @wyf007
- [ ] **设定追溯与矛盾检测** - 对大纲、世界观、角色档案中的设定支持悬停查看注释,显示相关章节来源和佐证原文;自动检测新章节与已有设定的矛盾(吃书),标记为"矛盾"设定并提供解决建议,当新设定解决矛盾后自动更新注释说明 @lulujiang
- [ ] **思维链与章节关系图谱** - 为每章建立思维链,总结与上文的逻辑关系、明暗线发展;可选的章节关系满图功能,自动识别和标注伏笔埋设与揭晓、角色出场与呼应等内在联系,帮助提升小说结构的紧密性和连贯性 @lulujiang
- [x] **灵感模式** - 创作灵感和点子生成
- [x] **自定义写作风格** - 支持自定义 AI 写作风格
- [x] **数据导入导出** - 项目数据的导入导出
- [x] **Prompt 调整界面** - 可视化编辑 Prompt 模板
- [x] **章节字数限制** - 用户可设置生成字数
- [x] **思维链与章节关系图谱** - 可视化章节逻辑关系
- [x] **根据分析一键重写** - 根据分析建议重新生成
- [x] **Linux DO 自动创建账号** - OAuth 登录自动生成账号
- [x] **职业等级体系** - 自定义职业和等级系统,支持修仙境界、魔法等级等多种体系
- [x] **角色/组织卡片导入导出** - 单独导出角色和组织卡片,支持跨项目数据共享
> 💡 如果你有其他功能建议,欢迎提交 Issue 或 Pull Request
### 📝 规划中功能
- [ ] **伏笔管理** - 智能追踪剧情伏笔,提醒未回收线索,可视化伏笔时间线
- [ ] **提示词工坊** - 社区驱动的 Prompt 模板分享平台,一键导入优质提示词
> 💡 欢迎提交 Issue 或 Pull Request
## 🚀 快速开始
### 前置要求
- **Docker 部署**Docker 和 Docker Compose
- **本地开发**Python 3.11+ 和 Node.js 18+
- **必需**:至少一个 AI 服务的 API KeyOpenAI/Gemini/Anthropic
- Docker 和 Docker Compose
- 至少一个 AI 服务的 API KeyOpenAI/Gemini/Claude
### 方式一:从源码构建 Docker 镜像
### Docker Compose 部署(推荐)
```bash
# 1. 克隆项目
git clone https://github.com/xiamuceer-j/MuMuAINovel.git
cd MuMuAINovel
# 2. 配置环境变量
# 2. 配置环境变量(必需)
cp backend/.env.example .env
# 编辑 .env 文件,填入你的 API Keys
# 编辑 .env 文件,填入必要配置(API Key、数据库密码等)
# 3. 启动服务(会自动构建镜像)
# 3. 确保文件准备完整
# ⚠️ 重要:确保以下文件存在
# - .env(配置文件,必需挂载到容器)
# - backend/scripts/init_postgres.sql(数据库初始化脚本)
# 4. 启动服务
docker-compose up -d
# 4. 访问应用
# 5. 访问应用
# 打开浏览器访问 http://localhost:8000
```
### 方式二:本地开发
> **📌 注意事项**
>
> 1. **`.env` 文件挂载**: `docker-compose.yml` 会自动将 `.env` 挂载到容器,确保文件存在
> 2. **数据库初始化**: `init_postgres.sql` 会在首次启动时自动执行,安装必要的PostgreSQL扩展
> 3. **自行构建**: 如需从源码构建,请先下载 embedding 模型文件([加群获取](frontend/public/qq.jpg)
#### 后端设置
### 使用 Docker Hub 镜像(推荐新手)
```bash
# 进入后端目录
cd backend
# 创建虚拟环境
python -m venv .venv
# 激活虚拟环境
# Windows:
.venv\Scripts\activate
# Linux/Mac:
source .venv/bin/activate
# 安装依赖
pip install -r requirements.txt
# 配置环境变量
cp .env.example .env
# 编辑 .env 文件,填入你的配置
# 启动后端服务
python -m uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
```
#### 前端设置
```bash
# 进入前端目录
cd frontend
# 安装依赖
npm install
# 开发模式(需要后端已启动)
npm run dev
# 或构建生产版本
npm run build
```
## 🐳 部署方式
### Docker Compose 部署
#### 使用 Docker Hub 镜像(推荐)
项目已发布到 Docker Hub,可直接拉取使用:
```bash
# 查看可用版本
# 1. 拉取最新镜像(已包含模型文件)
docker pull mumujie/mumuainovel:latest
# 2. 创建 docker-compose.yml(点击下方展开查看完整配置)
```
<details>
<summary>📄 点击展开 docker-compose.yml 完整配置</summary>
```yaml
services:
postgres:
image: postgres:18-alpine
container_name: mumuainovel-postgres
environment:
POSTGRES_DB: ${POSTGRES_DB:-mumuai_novel}
POSTGRES_USER: ${POSTGRES_USER:-mumuai}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-123456}
POSTGRES_INITDB_ARGS: "--encoding=UTF8 --locale=C"
TZ: ${TZ:-Asia/Shanghai}
volumes:
- postgres_data:/var/lib/postgresql/data
- ./backend/scripts/init_postgres.sql:/docker-entrypoint-initdb.d/init.sql:ro
ports:
- "${POSTGRES_PORT:-5432}:5432"
restart: unless-stopped
healthcheck:
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-mumuai} -d ${POSTGRES_DB:-mumuai_novel}"]
interval: 10s
timeout: 5s
retries: 5
start_period: 10s
networks:
- ai-story-network
command:
- postgres
- -c
- max_connections=${POSTGRES_MAX_CONNECTIONS:-200}
- -c
- shared_buffers=${POSTGRES_SHARED_BUFFERS:-256MB}
- -c
- effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-1GB}
- -c
- maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}
- -c
- checkpoint_completion_target=${POSTGRES_CHECKPOINT_COMPLETION_TARGET:-0.9}
- -c
- wal_buffers=${POSTGRES_WAL_BUFFERS:-16MB}
- -c
- default_statistics_target=${POSTGRES_DEFAULT_STATISTICS_TARGET:-100}
- -c
- random_page_cost=${POSTGRES_RANDOM_PAGE_COST:-1.1}
- -c
- effective_io_concurrency=${POSTGRES_EFFECTIVE_IO_CONCURRENCY:-200}
- -c
- work_mem=${POSTGRES_WORK_MEM:-4MB}
- -c
- min_wal_size=${POSTGRES_MIN_WAL_SIZE:-1GB}
- -c
- max_wal_size=${POSTGRES_MAX_WAL_SIZE:-4GB}
mumuainovel:
image: mumujie/mumuainovel:latest
container_name: mumuainovel
depends_on:
postgres:
condition: service_healthy
ports:
- "${APP_PORT:-8000}:8000"
volumes:
- ./logs:/app/logs
- ./.env:/app/.env:ro
environment:
# 应用配置
- APP_NAME=${APP_NAME:-MuMuAINovel}
- APP_VERSION=${APP_VERSION:-1.0.0}
- APP_HOST=${APP_HOST:-0.0.0.0}
- APP_PORT=8000
- DEBUG=${DEBUG:-false}
# 数据库配置
- DATABASE_URL=postgresql+asyncpg://${POSTGRES_USER:-mumuai}:${POSTGRES_PASSWORD:-123456}@postgres:5432/${POSTGRES_DB:-mumuai_novel}
- DB_HOST=postgres
- DB_PORT=5432
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-123456}
# PostgreSQL 连接池配置
- DATABASE_POOL_SIZE=${DATABASE_POOL_SIZE:-30}
- DATABASE_MAX_OVERFLOW=${DATABASE_MAX_OVERFLOW:-20}
- DATABASE_POOL_TIMEOUT=${DATABASE_POOL_TIMEOUT:-60}
- DATABASE_POOL_RECYCLE=${DATABASE_POOL_RECYCLE:-1800}
- DATABASE_POOL_PRE_PING=${DATABASE_POOL_PRE_PING:-True}
- DATABASE_POOL_USE_LIFO=${DATABASE_POOL_USE_LIFO:-True}
# 代理配置(可选)
- HTTP_PROXY=${HTTP_PROXY:-}
- HTTPS_PROXY=${HTTPS_PROXY:-}
- NO_PROXY=${NO_PROXY:-localhost,127.0.0.1}
# AI 服务配置
- OPENAI_API_KEY=${OPENAI_API_KEY:-}
- OPENAI_BASE_URL=${OPENAI_BASE_URL:-https://api.openai.com/v1}
- GEMINI_API_KEY=${GEMINI_API_KEY:-}
- GEMINI_BASE_URL=${GEMINI_BASE_URL:-}
- ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY:-}
- ANTHROPIC_BASE_URL=${ANTHROPIC_BASE_URL:-}
- DEFAULT_AI_PROVIDER=${DEFAULT_AI_PROVIDER:-openai}
- DEFAULT_MODEL=${DEFAULT_MODEL:-gpt-4o-mini}
- DEFAULT_TEMPERATURE=${DEFAULT_TEMPERATURE:-0.7}
- DEFAULT_MAX_TOKENS=${DEFAULT_MAX_TOKENS:-32000}
# LinuxDO OAuth 配置
- LINUXDO_CLIENT_ID=${LINUXDO_CLIENT_ID:-11111}
- LINUXDO_CLIENT_SECRET=${LINUXDO_CLIENT_SECRET:-11111}
- LINUXDO_REDIRECT_URI=${LINUXDO_REDIRECT_URI:-http://localhost:8000/api/auth/linuxdo/callback}
- FRONTEND_URL=${FRONTEND_URL:-http://localhost:8000}
# 本地账户登录配置
- LOCAL_AUTH_ENABLED=${LOCAL_AUTH_ENABLED:-true}
- LOCAL_AUTH_USERNAME=${LOCAL_AUTH_USERNAME:-admin}
- LOCAL_AUTH_PASSWORD=${LOCAL_AUTH_PASSWORD:-admin123}
- LOCAL_AUTH_DISPLAY_NAME=${LOCAL_AUTH_DISPLAY_NAME:-本地管理员}
# 会话配置
- SESSION_EXPIRE_MINUTES=${SESSION_EXPIRE_MINUTES:-120}
- SESSION_REFRESH_THRESHOLD_MINUTES=${SESSION_REFRESH_THRESHOLD_MINUTES:-30}
restart: unless-stopped
healthcheck:
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
interval: 30s
timeout: 10s
retries: 3
start_period: 30s
networks:
- ai-story-network
volumes:
postgres_data:
driver: local
networks:
ai-story-network:
driver: bridge
```
</details>
```bash
# 3. 启动服务
docker-compose up -d
# 4. 查看日志
docker-compose logs -f
# 5. 更新到最新版本
docker-compose pull
docker-compose up -d
```
> **💡 提示**: Docker Hub 镜像已包含所有依赖和模型文件,无需额外下载
### 本地开发 / 从源码构建
#### 前置准备
```bash
# ⚠️ 重要:如果从源码构建,需要先下载 embedding 模型文件
# 模型文件较大(约 400MB),需放置到以下目录:
# backend/embedding/models--sentence-transformers--paraphrase-multilingual-MiniLM-L12-v2/
#
# 📥 获取方式:
# - 加入项目 QQ 群或 Linux DO 讨论区获取下载链接
# - 群号:见项目主页
# - Linux DOhttps://linux.do/t/topic/1100112
```
#### 后端
```bash
cd backend
python -m venv .venv
source .venv/bin/activate # Windows: .venv\Scripts\activate
pip install -r requirements.txt
# 配置 .env 文件
cp .env.example .env
# 编辑 .env 填入必要配置
# 启动 PostgreSQL(可使用 Docker
docker run -d --name postgres \
-e POSTGRES_PASSWORD=your_password \
-e POSTGRES_DB=mumuai_novel \
-p 5432:5432 \
postgres:18-alpine
# 启动后端
python -m uvicorn app.main:app --host localhost --port 8000 --reload
```
#### 前端
```bash
cd frontend
npm install
npm run dev # 开发模式
npm run build # 生产构建
```
## ⚙️ 配置说明
### 必需配置
创建 `.env` 文件:
```bash
# PostgreSQL 数据库(必需)
DATABASE_URL=postgresql+asyncpg://mumuai:your_password@postgres:5432/mumuai_novel
POSTGRES_PASSWORD=your_secure_password
# AI 服务
OPENAI_API_KEY=your_openai_key
OPENAI_BASE_URL=https://api.openai.com/v1
DEFAULT_AI_PROVIDER=openai
DEFAULT_MODEL=gpt-4o-mini
# 本地账户登录
LOCAL_AUTH_ENABLED=true
LOCAL_AUTH_USERNAME=admin
LOCAL_AUTH_PASSWORD=your_password
```
### 可选配置
```bash
# LinuxDO OAuth
LINUXDO_CLIENT_ID=your_client_id
LINUXDO_CLIENT_SECRET=your_client_secret
LINUXDO_REDIRECT_URI=http://localhost:8000/api/auth/callback
# PostgreSQL 连接池(高并发优化)
DATABASE_POOL_SIZE=30
DATABASE_MAX_OVERFLOW=20
```
### 中转 API 配置
支持所有 OpenAI 兼容格式的中转服务:
```bash
# New API 示例
OPENAI_API_KEY=sk-xxxxxxxx
OPENAI_BASE_URL=https://api.new-api.com/v1
# 其他中转服务
OPENAI_BASE_URL=https://your-proxy-service.com/v1
```
## 🐳 Docker 部署详情
### 服务架构
- **postgres**: PostgreSQL 18 数据库
- 端口: 5432
- 数据持久化: `postgres_data` volume
- 初始化脚本: `backend/scripts/init_postgres.sql`(自动挂载)
- 优化配置: 支持 80-150 并发用户
- **mumuainovel**: 主应用服务
- 端口: 8000
- 日志目录: `./logs`
- 配置挂载: `.env` 文件
- 自动等待数据库就绪
- 健康检查: 每 30 秒检测一次
### 重要文件说明
| 文件 | 说明 | 是否必需 |
|------|------|---------|
| `.env` | 环境配置(API Key、数据库密码等) | ✅ 必需 |
| `docker-compose.yml` | 服务编排配置 | ✅ 必需 |
| `backend/scripts/init_postgres.sql` | PostgreSQL 扩展安装脚本 | ✅ 自动挂载 |
| `backend/embedding/models--*/` | Embedding 模型文件 | ⚠️ 自建需要 |
> **注意**: 使用 Docker Hub 镜像时,模型文件已包含在镜像中,无需额外下载
### 常用命令
```bash
# 启动服务
docker-compose up -d
# 查看状态
docker-compose ps
# 查看日志
docker-compose logs -f
@@ -138,346 +440,67 @@ docker-compose down
# 重启服务
docker-compose restart
# 更新到最新版本
docker-compose pull
docker-compose up -d
# 查看资源使用
docker stats
```
#### Docker Compose 配置文件示例
### 数据持久化
使用 Docker Hub 镜像的完整配置:
- `./postgres_data` - PostgreSQL 数据库文件
- `./logs` - 应用日志文件
### 端口配置
修改 `docker-compose.yml` 中的端口映射:
```yaml
services:
ai-story:
image: mumujie/mumuainovel:latest
container_name: mumuainovel
ports:
- "8800:8000" # 宿主机端口:容器端口
volumes:
# 持久化数据库和日志
- ./data:/app/data
- ./logs:/app/logs
# 挂载环境变量文件
- ./.env:/app/.env:ro
environment:
- APP_NAME=mumuainovel
- APP_VERSION=1.0.0
- APP_HOST=0.0.0.0
- APP_PORT=8000
- DEBUG=false
# 其他环境变量会从 .env 文件自动加载
restart: unless-stopped
healthcheck:
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
interval: 30s
timeout: 10s
retries: 3
start_period: 10s
networks:
- ai-story-network
networks:
ai-story-network:
driver: bridge
```
### 生产环境部署建议
#### 1. 环境变量配置
**必需配置**
- `OPENAI_API_KEY``GEMINI_API_KEY`:至少配置一个 AI 服务
- `LOCAL_AUTH_PASSWORD`:修改为强密码
**推荐配置**
- `OPENAI_BASE_URL`:如果使用中转 API,修改为中转服务地址
- `DEFAULT_AI_PROVIDER`:根据你的 API Key 选择 `openai``gemini``anthropic`
- `DEFAULT_MODEL`:选择合适的模型(如 `gpt-4o-mini``gemini-2.0-flash-exp`
#### 2. 数据持久化
数据目录已通过 volume 挂载,数据不会丢失:
- `./data`SQLite 数据库文件
- `./logs`:应用日志文件
#### 3. 端口配置
默认端口映射:`8800:8000`
- 宿主机端口:`8800`(可自定义修改)
- 容器内端口:`8000`(固定,不要修改)
访问地址:`http://your-server-ip:8800`
配置后记得更新 `.env` 中的 `LINUXDO_REDIRECT_URI``FRONTEND_URL`
#### 5. 资源限制(可选)
`docker-compose.yml` 中添加资源限制:
```yaml
services:
ai-story:
# ... 其他配置
deploy:
resources:
limits:
cpus: '2.0'
memory: 2G
reservations:
cpus: '0.5'
memory: 512M
```
### 端口说明
- **默认端口**`8800`(宿主机)→ `8000`(容器)
- **可自定义**:修改 docker-compose.yml 中的 `ports` 配置
- **健康检查**:容器内部使用 `8000` 端口进行健康检查
## ⚙️ 配置说明
### 环境变量
创建 `.env` 文件并配置以下变量:
```bash
# ===== AI 服务配置(必填)=====
# OpenAI 配置(支持官方API和中转API)
OPENAI_API_KEY=your_openai_key_here
OPENAI_BASE_URL=https://api.openai.com/v1
# Anthropic 配置
# ANTHROPIC_API_KEY=your_anthropic_key_here
# ANTHROPIC_BASE_URL=https://api.anthropic.com
# 中转API配置示例(使用OpenAI格式)
# New API 中转服务
# OPENAI_API_KEY=your_newapi_key_here
# OPENAI_BASE_URL=https://api.new-api.com/v1
# 默认 AI 提供商和模型
DEFAULT_AI_PROVIDER=openai
DEFAULT_MODEL=gpt-4o-mini
DEFAULT_TEMPERATURE=0.8
DEFAULT_MAX_TOKENS=32000
# ===== 应用配置 =====
APP_NAME=MuMuAINovel
APP_VERSION=1.0.0
APP_HOST=0.0.0.0
APP_PORT=8000
DEBUG=false
# ===== LinuxDO OAuth 配置(可选)=====
LINUXDO_CLIENT_ID=your_client_id_here
LINUXDO_CLIENT_SECRET=your_client_secret_here
LINUXDO_REDIRECT_URI=http://localhost:8000/api/auth/callback
FRONTEND_URL=http://localhost:8000
# ===== 本地账户登录配置 =====
LOCAL_AUTH_ENABLED=true
LOCAL_AUTH_USERNAME=admin
LOCAL_AUTH_PASSWORD=your_secure_password_here
LOCAL_AUTH_DISPLAY_NAME=管理员
# 会话配置
# 会话过期时间(分钟),默认120分钟(2小时)
SESSION_EXPIRE_MINUTES=120
# 会话刷新阈值(分钟),剩余时间少于此值时可刷新,默认30分钟
SESSION_REFRESH_THRESHOLD_MINUTES=30
# ===== CORS 配置(生产环境)=====
# CORS_ORIGINS=https://your-domain.com,https://www.your-domain.com
```
### AI 模型配置
项目支持多个 AI 提供商,你可以根据需要配置:
| 提供商 | 推荐模型 | 用途 |
|--------|---------|------|
| OpenAI | gpt-4, gpt-3.5-turbo | 高质量文本生成 |
| Anthropic | claude-3-opus, claude-3-sonnet | 长文本创作 |
#### 使用中转API服务
如果你无法直接访问 OpenAI 官方 API,或者想使用更经济实惠的中转服务,本项目完全支持各种 OpenAI 兼容格式的中转 API
##### 配置方法
只需修改 `.env` 文件中的两个参数:
```bash
# 1. 填入中转服务提供的 API Key
OPENAI_API_KEY=your_api_key_from_proxy_service
# 2. 修改 Base URL 为中转服务的地址
OPENAI_BASE_URL=https://your-proxy-service.com/v1
```
##### 常见中转服务配置示例
**New API**
```bash
OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxx
OPENAI_BASE_URL=https://api.new-api.com/v1
```
**API2D**
```bash
OPENAI_API_KEY=fk-xxxxxxxxxxxxxxxx
OPENAI_BASE_URL=https://api.api2d.com/v1
```
**OpenAI-SB**
```bash
OPENAI_API_KEY=sb-xxxxxxxxxxxxxxxx
OPENAI_BASE_URL=https://api.openai-sb.com/v1
```
**自建 One API / New API**
```bash
OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxx
OPENAI_BASE_URL=https://your-domain.com/v1
```
##### 注意事项
- ✅ 所有支持 OpenAI 接口格式的服务都可以使用
- ✅ 确保中转服务的 Base URL 以 `/v1` 结尾
- ✅ 根据中转服务支持的模型,修改 `DEFAULT_MODEL` 参数
- ⚠️ 不同中转服务的模型名称可能不同,请参考服务商文档
- ⚠️ 部分中转服务可能对请求频率或并发有限制
##### 推荐的中转服务
如果你需要中转服务,以下是一些常见选择:
1. **New API** - 开源的 API 分发系统,支持多种模型
2. **API2D** - 国内稳定的 API 中转服务
3. **OpenAI-SB** - 提供多种 AI 模型的中转
4. **自建服务** - 使用 One API 或 New API 自行搭建
> 💡 提示:使用中转服务时,请确保服务提供商的可靠性和数据安全性
### 登录方式配置
#### 本地账户登录(默认启用)
适合个人使用或小型团队:
```bash
LOCAL_AUTH_ENABLED=true
LOCAL_AUTH_USERNAME=admin
LOCAL_AUTH_PASSWORD=your_password
```
#### LinuxDO OAuth 登录
适合需要社区集成的场景,需要在 [LinuxDO](https://linux.do) 注册 OAuth 应用:
```bash
LINUXDO_CLIENT_ID=your_client_id
LINUXDO_CLIENT_SECRET=your_client_secret
LINUXDO_REDIRECT_URI=http://your-domain:8000/api/auth/callback
ports:
- "8800:8000" # 宿主机:容器
```
## 📁 项目结构
```
MuMuAINovel/
├── backend/ # 后端服务
├── backend/ # 后端服务
│ ├── app/
│ │ ├── api/ # API 路由
│ │ │ ├── auth.py # 认证接口
│ │ │ ├── projects.py # 项目管理
│ │ │ ├── chapters.py # 章节管理
│ │ │ ├── characters.py # 角色管理
│ │ │ ├── wizard_stream.py # 向导流式生成
└── ...
│ ├── models/ # 数据模型
│ │ ├── schemas/ # Pydantic 模型
│ │ ├── services/ # 业务逻辑
│ │ │ ├── ai_service.py # AI 服务封装
│ │ │ └── oauth_service.py # OAuth 服务
│ │ ├── middleware/ # 中间件
│ │ ├── utils/ # 工具函数
│ │ ├── config.py # 配置管理
│ │ ├── database.py # 数据库连接
│ │ └── main.py # 应用入口
│ ├── data/ # 数据存储目录
│ ├── static/ # 前端静态文件(构建后)
│ ├── requirements.txt # Python 依赖
│ └── .env.example # 环境变量示例
├── frontend/ # 前端应用
│ │ ├── api/ # API 路由
│ │ ├── models/ # 数据模型
│ │ ├── services/ # 业务逻辑
│ │ ├── middleware/ # 中间件
│ │ ├── database.py # 数据库连接
│ │ └── main.py # 应用入口
├── scripts/ # 工具脚本
└── requirements.txt # Python 依赖
├── frontend/ # 前端应用
│ ├── src/
│ │ ├── pages/ # 页面组件
│ │ │ ├── ProjectList.tsx # 项目列表
│ │ │ ├── ProjectWizardNew.tsx # 创建向导
│ │ │ ├── Chapters.tsx # 章节管理
│ │ ├── Characters.tsx # 角色管理
│ │ │ └── ...
│ │ ├── components/ # 通用组件
│ │ ├── services/ # API 服务
│ │ ├── store/ # 状态管理(Zustand
│ │ ├── types/ # TypeScript 类型
│ │ └── utils/ # 工具函数
│ ├── package.json
│ └── vite.config.ts
├── docker-compose.yml # Docker Compose 配置
├── Dockerfile # Docker 镜像构建
└── README.md # 项目说明文档
│ │ ├── pages/ # 页面组件
│ │ ├── components/ # 通用组件
│ │ ├── services/ # API 服务
│ │ └── store/ # 状态管理
└── package.json
├── docker-compose.yml # Docker Compose 配置
├── Dockerfile # Docker 镜像构建
└── README.md
```
## 🛠️ 技术栈
### 后端
**后端**: FastAPI • PostgreSQL • SQLAlchemy • OpenAI/Claude/Gemini SDK
- **框架**FastAPI 0.109.0
- **数据库**SQLite + SQLAlchemy(异步)
- **AI 集成**OpenAI、Anthropic、Google Gemini SDK
- **认证**LinuxDO OAuth2、本地账户
- **日志**Python logging + 文件轮转
### 前端
- **框架**React 18.3 + TypeScript
- **UI 库**Ant Design 5.27
- **路由**React Router 6.28
- **状态管理**Zustand 5.0
- **HTTP 客户端**Axios
- **构建工具**Vite 7.1
**前端**: React 18 • TypeScript • Ant Design • Zustand • Vite
## 📖 使用指南
### 创建第一个小说项目
1. **登录系统**
- 使用本地账户或 LinuxDO 账户登录
2. **创建项目**
- 点击"创建项目"按钮
- 选择"使用向导创建"或"手动创建"
3. **使用向导(推荐)**
- 输入小说基本信息(标题、类型、背景等)
- AI 自动生成大纲、角色和世界观
- 实时查看生成进度
4. **编辑和完善**
- 在项目详情页查看和编辑大纲
- 管理角色和人物关系
- 生成和编辑章节内容
1. **登录系统** - 使用本地账户或 LinuxDO 账户
2. **创建项目** - 选择"使用向导创建"
3. **AI 生成** - 输入基本信息,AI 自动生成大纲和角色
4. **编辑完善** - 管理角色关系,生成和编辑章节
### API 文档
应用启动后,可访问自动生成的 API 文档:
- Swagger UI`http://localhost:8000/docs`
- ReDoc`http://localhost:8000/redoc`
- Swagger UI: `http://localhost:8000/docs`
- ReDoc: `http://localhost:8000/redoc`
## 🤝 贡献
@@ -489,40 +512,44 @@ MuMuAINovel/
4. 推送到分支 (`git push origin feature/AmazingFeature`)
5. 提交 Pull Request
### 贡献者
感谢所有为本项目做出贡献的开发者!
<a href="https://github.com/xiamuceer-j/MuMuAINovel/graphs/contributors">
<img src="https://contrib.rocks/image?repo=xiamuceer-j/MuMuAINovel" />
</a>
## 📝 许可证
本项目采用 [GNU General Public License v3.0](https://www.gnu.org/licenses/gpl-3.0.html) 开源协议
本项目采用 [GNU General Public License v3.0](LICENSE)
**意味着:**
-**可以** - 自由使用、复制、修改和分发本项目
- **可以** - 用于商业目的
- **可以** - 用于个人学习和研究
- 📝 **必须** - 开源你的修改版本
- 📝 **必须** - 保留原作者版权声明
- 📝 **必须** - 以相同的 GPL v3 协议发布衍生作品
详见 [LICENSE](LICENSE) 文件
**GPL v3 意味着:**
- ✅ 可自由使用、修改和分发
-可用于商业目的
- 📝 必须开源修改版本
- 📝 必须保留原作者版权
- 📝 衍生作品必须使用 GPL v3 协议
## 🙏 致谢
- [FastAPI](https://fastapi.tiangolo.com/) - 现代化的 Python Web 框架
- [React](https://react.dev/) - 用户界面构建库
- [Ant Design](https://ant.design/) - 企业级 UI 设计语言
- [OpenAI](https://openai.com/) / [Anthropic](https://www.anthropic.com/) - AI 模型提供商
- [FastAPI](https://fastapi.tiangolo.com/) - Python Web 框架
- [React](https://react.dev/) - 前端框架
- [Ant Design](https://ant.design/) - UI 组件库
- [PostgreSQL](https://www.postgresql.org/) - 数据库
## 📧 联系方式
如有问题或建议,欢迎通过以下方式联系:
- 提交 [Issue](https://github.com/yourusername/MuMuAINovel/issues)
- Linux DO [LD](https://linux.do/t/topic/1100112)
- 提交 [Issue](https://github.com/xiamuceer-j/MuMuAINovel/issues)
- Linux DO [讨论](https://linux.do/t/topic/1106333)
- 加入QQ群 [QQ群](frontend/public/qq.jpg)
- 加入WX群 [WX群](frontend/public/WX.png)
---
<div align="center">
**如果这个项目对你有帮助,请给个 ⭐️ Star 支持一下**
**如果这个项目对你有帮助,请给个 ⭐️ Star!**
Made with ❤️
+78 -38
View File
@@ -1,54 +1,94 @@
# AI服务配置
# OpenAI配置
OPENAI_API_KEY=your_openai_key_here
OPENAI_BASE_URL=https://api.openai.com/v1
# Anthropic配置
ANTHROPIC_API_KEY=your_anthropic_key_here
ANTHROPIC_BASE_URL=https://api.anthropic.com
# 默认AI提供商:openai, gemini, anthropic
DEFAULT_AI_PROVIDER=openai
DEFAULT_MODEL=gpt-4.1
DEFAULT_TEMPERATURE=0.8
DEFAULT_MAX_TOKENS=32000
# ==========================================
# MuMuAINovel 配置文件示例
# ==========================================
# 复制此文件为 .env 并修改配置值
# cp .env.example .env
# ==========================================
# 应用配置
# ==========================================
APP_NAME=MuMuAINovel
APP_VERSION=1.0.0
APP_VERSION=1.2.6
APP_HOST=0.0.0.0
APP_PORT=8000
DEBUG=true
DEBUG=false
TZ=Asia/Shanghai
# LinuxDO OAuth2 配置(可选)
# 注意:Docker部署时,LINUXDO_REDIRECT_URI 应该使用实际的域名或服务器IP
# 本地开发: http://localhost:8000/api/auth/callback
# 生产环境: https://your-domain.com/api/auth/callback 或 http://your-server-ip:8000/api/auth/callback
LINUXDO_CLIENT_ID=your_client_id_here
LINUXDO_CLIENT_SECRET=your_client_secret_here
# ==========================================
# PostgreSQL 数据库配置
# ==========================================
# PostgreSQL 连接信息
POSTGRES_DB=mumuai_novel
POSTGRES_USER=mumuai
POSTGRES_PASSWORD=123456
POSTGRES_PORT=5432
# 数据库连接 URL(Docker 部署时自动生成)
# DATABASE_URL=postgresql+asyncpg://mumuai:123456@localhost:5432/mumuai_novel
# ==========================================
# SQLite 数据库配置
# ==========================================
# DATABASE_URL=sqlite+aiosqlite:///data/ai_story.db
# ==========================================
# 日志配置
# ==========================================
LOG_LEVEL=INFO
LOG_TO_FILE=true
LOG_FILE_PATH=logs/app.log
LOG_MAX_BYTES=10485760
LOG_BACKUP_COUNT=30
# ==========================================
# CORS 配置
# ==========================================
CORS_ORIGINS=["http://localhost:8000","http://127.0.0.1:8000"]
# ==========================================
# 代理配置(可选)
# ==========================================
# HTTP_PROXY=http://your-proxy:port
# HTTPS_PROXY=http://your-proxy:port
# NO_PROXY=localhost,127.0.0.1
# ==========================================
# AI 服务配置(至少配置一个)
# ==========================================
# OpenAI 配置
OPENAI_API_KEY=your_openai_api_key_here
OPENAI_BASE_URL=https://api.openai.com/v1
# 默认 AI 配置
DEFAULT_AI_PROVIDER=openai
DEFAULT_MODEL=gpt-4o-mini
DEFAULT_TEMPERATURE=0.7
DEFAULT_MAX_TOKENS=32000
# ==========================================
# LinuxDO OAuth 配置(可选)
# ==========================================
LINUXDO_CLIENT_ID=11111
LINUXDO_CLIENT_SECRET=11111
LINUXDO_REDIRECT_URI=http://localhost:8000/api/auth/callback
# 前端URL配置(用于OAuth回调后重定向到前端)
# 本地开发: http://localhost:8000
# 生产环境: https://your-domain.com 或 http://your-server-ip:8000
FRONTEND_URL=http://localhost:8000
# 初始管理员(LinuxDO user_id
# INITIAL_ADMIN_LINUXDO_ID=your_linuxdo_user_id
# ==========================================
# 本地账户登录配置
# 启用本地账户登录(true/false)
# ==========================================
LOCAL_AUTH_ENABLED=true
# 本地登录用户名
LOCAL_AUTH_USERNAME=admin
# 本地登录密码
LOCAL_AUTH_PASSWORD=your_secure_password_here
# 本地用户显示名称
LOCAL_AUTH_DISPLAY_NAME=管理员
LOCAL_AUTH_PASSWORD=admin123
LOCAL_AUTH_DISPLAY_NAME=本地管理员
# ==========================================
# 会话配置
# 会话过期时间(分钟),默认120分钟(2小时)
# ==========================================
SESSION_EXPIRE_MINUTES=120
# 会话刷新阈值(分钟),剩余时间少于此值时可刷新,默认30分钟
SESSION_REFRESH_THRESHOLD_MINUTES=30
# CORS配置(生产环境)
# 允许的跨域来源,多个用逗号分隔
# CORS_ORIGINS=https://your-domain.com,https://www.your-domain.com
+48
View File
@@ -0,0 +1,48 @@
# Alembic Database Migration Profile - PostgreSQL
# Database version management for the MuMuAINovel project
[alembic]
# Migration Script storage directory (PostgreSQL)
script_location = alembic/postgres
# Template File Path (for generating migration scripts)
file_template = %%(year)d%%(month).2d%%(day).2d_%%(hour).2d%%(minute).2d_%%(rev)s_%%(slug)s
# Database connection string
# Note: The actual connection string is read from the environment variable in env.py
# sqlalchemy.url = postgresql+asyncpg://mumuai:password@localhost:5432/mumuai_novel
# Log Configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
+48
View File
@@ -0,0 +1,48 @@
# Alembic Database Migration Profile - SQLite
# Database version management for the MuMuAINovel project
[alembic]
# Migration Script storage directory (SQLite)
script_location = alembic/sqlite
# Template File Path (for generating migration scripts)
file_template = %%(year)d%%(month).2d%%(day).2d_%%(hour).2d%%(minute).2d_%%(rev)s_%%(slug)s
# Database connection string
# Note: The actual connection string is read from the environment variable in env.py
# sqlalchemy.url = sqlite+aiosqlite:///data/ai_story.db
# Log Configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
+145
View File
@@ -0,0 +1,145 @@
# Alembic 数据库迁移指南
本项目支持 **PostgreSQL** 和 **SQLite** 两种数据库,使用独立的 Alembic 配置管理迁移。
## 📁 目录结构
```
backend/
├── alembic-postgres.ini # PostgreSQL 配置文件
├── alembic-sqlite.ini # SQLite 配置文件
├── alembic/
│ ├── postgres/ # PostgreSQL 迁移脚本目录
│ │ ├── env.py
│ │ ├── script.py.mako
│ │ └── versions/ # PostgreSQL 迁移版本
│ │ ├── 20251226_1008_ee0a189f1532_初始数据库结构.py
│ │ └── 20251226_1102_e411428f00c0_初始化预置数据.py
│ └── sqlite/ # SQLite 迁移脚本目录
│ ├── env.py
│ ├── script.py.mako
│ └── versions/ # SQLite 迁移版本
│ └── 20251226_1322_fbeb1038c728_初始化sqlite数据库.py
```
## 🚀 使用方法
### 1. PostgreSQL 数据库
#### 配置环境变量
```bash
# .env 文件
DATABASE_URL=postgresql+asyncpg://username:password@localhost:5432/database_name
```
#### 生成迁移脚本
```bash
cd backend
alembic -c alembic-postgres.ini revision --autogenerate -m "描述信息"
```
#### 应用迁移
```bash
alembic -c alembic-postgres.ini upgrade head
```
#### 回退迁移
```bash
alembic -c alembic-postgres.ini downgrade -1
```
#### 查看迁移历史
```bash
alembic -c alembic-postgres.ini history
alembic -c alembic-postgres.ini current
```
### 2. SQLite 数据库
#### 配置环境变量
```bash
# .env 文件
DATABASE_URL=sqlite+aiosqlite:///./data/mumuai.db
```
#### 生成迁移脚本
```bash
cd backend
alembic -c alembic-sqlite.ini revision --autogenerate -m "描述信息"
```
#### 应用迁移
```bash
alembic -c alembic-sqlite.ini upgrade head
```
#### 回退迁移
```bash
alembic -c alembic-sqlite.ini downgrade -1
```
#### 查看迁移历史
```bash
alembic -c alembic-sqlite.ini history
alembic -c alembic-sqlite.ini current
```
## ⚙️ 关键配置差异
### PostgreSQL (alembic/postgres/env.py)
- `render_as_batch=False` - 直接支持 ALTER TABLE
- 使用 `server_default=sa.text('now()')`
### SQLite (alembic/sqlite/env.py)
- `render_as_batch=True` - 通过重建表实现 ALTER TABLE
- 使用 `server_default=sa.text('(CURRENT_TIMESTAMP)')` - SQLite 格式
## 📝 注意事项
### SQLite 限制
1. **并发写入**:同时只允许一个写操作
2. **ALTER TABLE 限制**:某些操作需要重建表(Alembic 的批处理模式会自动处理)
3. **类型映射**
- `JSON` → `TEXT` (SQLAlchemy 自动处理)
- `BOOLEAN` → `INTEGER` (0/1)
- `DEFAULT now()` → `DEFAULT CURRENT_TIMESTAMP`
### PostgreSQL 优势
1. **高并发支持**:多用户同时读写
2. **完整的 ALTER TABLE 支持**
3. **高级特性**:全文搜索、JSON 操作符、数组类型等
## 🔄 切换数据库
只需修改 `.env` 文件中的 `DATABASE_URL`,然后使用对应的配置文件执行迁移:
```bash
# 切换到 SQLite
DATABASE_URL=sqlite+aiosqlite:///./data/mumuai.db
alembic -c alembic-sqlite.ini upgrade head
# 切换到 PostgreSQL
DATABASE_URL=postgresql+asyncpg://user:pass@localhost:5432/db
alembic -c alembic-postgres.ini upgrade head
```
## 💡 最佳实践
1. **开发环境**:使用 SQLite(简单、无需额外服务)
2. **生产环境**:使用 PostgreSQL(性能、并发、稳定性)
3. **保持同步**:两个数据库的模型定义必须一致
4. **测试迁移**:在两种数据库上都测试迁移脚本
## 🐛 常见问题
### Q: 迁移脚本生成后可以通用吗?
A: 不行。PostgreSQL 和 SQLite 的迁移脚本是独立的,因为:
- SQL 语法差异(如 DEFAULT 值)
- 类型差异(如 JSON、BOOLEAN
- ALTER TABLE 能力差异
### Q: 如何从 PostgreSQL 迁移数据到 SQLite
A: 需要编写数据导出/导入脚本,不能直接复用迁移脚本。
### Q: 为什么 SQLite 迁移这么慢?
A: SQLite 的 ALTER TABLE 限制导致需要重建表,这在大表时会很慢。
+2
View File
@@ -0,0 +1,2 @@
# 此文件确保 versions 目录被 Git 追踪
# 迁移版本文件将存放在此目录
+101
View File
@@ -0,0 +1,101 @@
"""Alembic 环境配置文件 - PostgreSQL"""
import asyncio
import os
import sys
from logging.config import fileConfig
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import async_engine_from_config
from alembic import context
# 添加项目根目录到 Python 路径
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if project_root not in sys.path:
sys.path.insert(0, project_root)
# 导入应用配置
from app.config import settings
# 导入 Base 和所有模型
from app.database import Base
from app.models import (
Project, Outline, Character, Chapter, GenerationHistory,
Settings, WritingStyle, ProjectDefaultStyle,
RelationshipType, CharacterRelationship, Organization, OrganizationMember,
StoryMemory, PlotAnalysis, AnalysisTask, BatchGenerationTask,
RegenerationTask, Career, CharacterCareer, User, MCPPlugin, PromptTemplate
)
# Alembic Config 对象
config = context.config
# 设置数据库连接字符串(从环境变量读取)
config.set_main_option("sqlalchemy.url", settings.database_url)
# 配置日志
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# 设置 target_metadata 为应用的 Base.metadata
target_metadata = Base.metadata
def run_migrations_offline() -> None:
"""'离线'模式下运行迁移"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
compare_type=True,
compare_server_default=True,
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection: Connection) -> None:
"""执行迁移的核心函数 - PostgreSQL 专用"""
context.configure(
connection=connection,
target_metadata=target_metadata,
compare_type=True,
compare_server_default=True,
render_as_batch=False, # PostgreSQL 不需要批处理模式
)
with context.begin_transaction():
context.run_migrations()
async def run_async_migrations() -> None:
"""'在线'模式下运行异步迁移"""
configuration = config.get_section(config.config_ini_section, {})
configuration["sqlalchemy.url"] = settings.database_url
connectable = async_engine_from_config(
configuration,
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
await connectable.dispose()
def run_migrations_online() -> None:
"""'在线'模式下运行迁移"""
asyncio.run(run_async_migrations())
# 根据上下文选择运行模式
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
+26
View File
@@ -0,0 +1,26 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}
@@ -0,0 +1,554 @@
"""初始数据库结构
Revision ID: ee0a189f1532
Revises:
Create Date: 2025-12-26 10:08:55.432217
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'ee0a189f1532'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""升级数据库结构"""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('batch_generation_tasks',
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('project_id', sa.String(length=36), nullable=False, comment='项目ID'),
sa.Column('user_id', sa.String(length=100), nullable=False, comment='用户ID'),
sa.Column('start_chapter_number', sa.Integer(), nullable=False, comment='起始章节序号'),
sa.Column('chapter_count', sa.Integer(), nullable=False, comment='生成章节数量'),
sa.Column('chapter_ids', sa.JSON(), nullable=False, comment='待生成的章节ID列表'),
sa.Column('style_id', sa.Integer(), nullable=True, comment='使用的写作风格ID'),
sa.Column('target_word_count', sa.Integer(), nullable=True, comment='目标字数'),
sa.Column('enable_analysis', sa.Boolean(), nullable=True, comment='是否启用同步分析'),
sa.Column('status', sa.String(length=20), nullable=True, comment='任务状态: pending/running/completed/failed/cancelled'),
sa.Column('total_chapters', sa.Integer(), nullable=True, comment='总章节数'),
sa.Column('completed_chapters', sa.Integer(), nullable=True, comment='已完成章节数'),
sa.Column('failed_chapters', sa.JSON(), nullable=True, comment='失败的章节信息列表'),
sa.Column('current_chapter_id', sa.String(length=36), nullable=True, comment='当前正在生成的章节ID'),
sa.Column('current_chapter_number', sa.Integer(), nullable=True, comment='当前正在生成的章节序号'),
sa.Column('current_retry_count', sa.Integer(), nullable=True, comment='当前章节重试次数'),
sa.Column('max_retries', sa.Integer(), nullable=True, comment='最大重试次数'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
sa.Column('started_at', sa.DateTime(), nullable=True, comment='开始时间'),
sa.Column('completed_at', sa.DateTime(), nullable=True, comment='完成时间'),
sa.Column('error_message', sa.String(length=500), nullable=True, comment='错误信息'),
sa.PrimaryKeyConstraint('id')
)
op.create_table('mcp_plugins',
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('user_id', sa.String(length=50), nullable=False, comment='用户ID'),
sa.Column('plugin_name', sa.String(length=100), nullable=False, comment='插件名称(唯一标识)'),
sa.Column('display_name', sa.String(length=200), nullable=False, comment='显示名称'),
sa.Column('description', sa.Text(), nullable=True, comment='插件描述'),
sa.Column('plugin_type', sa.String(length=50), nullable=True, comment='插件类型:http/stdio'),
sa.Column('server_url', sa.String(length=500), nullable=True, comment='服务器URLHTTP类型)'),
sa.Column('command', sa.String(length=500), nullable=True, comment='启动命令(stdio类型)'),
sa.Column('args', sa.JSON(), nullable=True, comment='命令参数(stdio类型)'),
sa.Column('env', sa.JSON(), nullable=True, comment='环境变量'),
sa.Column('headers', sa.JSON(), nullable=True, comment='HTTP请求头'),
sa.Column('config', sa.JSON(), nullable=True, comment='插件特定配置(JSON'),
sa.Column('tools', sa.JSON(), nullable=True, comment='提供的工具列表'),
sa.Column('enabled', sa.Boolean(), nullable=True, comment='是否启用'),
sa.Column('status', sa.String(length=50), nullable=True, comment='状态:active/inactive/error'),
sa.Column('last_error', sa.Text(), nullable=True, comment='最后错误信息'),
sa.Column('last_test_at', sa.DateTime(), nullable=True, comment='最后测试时间'),
sa.Column('category', sa.String(length=100), nullable=True, comment='分类'),
sa.Column('sort_order', sa.Integer(), nullable=True, comment='排序顺序'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
sa.PrimaryKeyConstraint('id')
)
op.create_index('idx_user_enabled', 'mcp_plugins', ['user_id', 'enabled'], unique=False)
op.create_index('idx_user_plugin', 'mcp_plugins', ['user_id', 'plugin_name'], unique=True)
op.create_index(op.f('ix_mcp_plugins_user_id'), 'mcp_plugins', ['user_id'], unique=False)
op.create_table('projects',
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('user_id', sa.String(length=100), nullable=False, comment='用户ID'),
sa.Column('title', sa.String(length=200), nullable=False, comment='项目标题'),
sa.Column('description', sa.Text(), nullable=True, comment='项目简介'),
sa.Column('theme', sa.Text(), nullable=True, comment='主题'),
sa.Column('genre', sa.String(length=50), nullable=True, comment='小说类型'),
sa.Column('target_words', sa.Integer(), nullable=True, comment='目标字数'),
sa.Column('current_words', sa.Integer(), nullable=True, comment='当前字数'),
sa.Column('status', sa.String(length=20), nullable=True, comment='创作状态'),
sa.Column('wizard_status', sa.String(length=20), nullable=True, comment='向导完成状态: incomplete/completed'),
sa.Column('wizard_step', sa.Integer(), nullable=True, comment='向导当前步骤: 0-4'),
sa.Column('outline_mode', sa.String(length=20), nullable=False, comment='大纲章节模式: one-to-one(传统模式) 或 one-to-many(细化模式)'),
sa.Column('world_time_period', sa.Text(), nullable=True, comment='时间背景'),
sa.Column('world_location', sa.Text(), nullable=True, comment='地理位置'),
sa.Column('world_atmosphere', sa.Text(), nullable=True, comment='氛围基调'),
sa.Column('world_rules', sa.Text(), nullable=True, comment='世界规则'),
sa.Column('chapter_count', sa.Integer(), nullable=True, comment='章节数量'),
sa.Column('narrative_perspective', sa.String(length=50), nullable=True, comment='叙事视角:first_person/third_person/omniscient'),
sa.Column('character_count', sa.Integer(), nullable=True, comment='角色数量'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
sa.CheckConstraint("outline_mode IN ('one-to-one', 'one-to-many')", name='check_outline_mode'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_projects_user_id'), 'projects', ['user_id'], unique=False)
op.create_table('prompt_templates',
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('user_id', sa.String(length=50), nullable=False, comment='用户ID'),
sa.Column('template_key', sa.String(length=100), nullable=False, comment='模板键名'),
sa.Column('template_name', sa.String(length=200), nullable=False, comment='模板显示名称'),
sa.Column('template_content', sa.Text(), nullable=False, comment='模板内容'),
sa.Column('description', sa.Text(), nullable=True, comment='模板描述'),
sa.Column('category', sa.String(length=50), nullable=True, comment='模板分类'),
sa.Column('parameters', sa.Text(), nullable=True, comment='模板参数定义(JSON)'),
sa.Column('is_active', sa.Boolean(), nullable=True, comment='是否启用'),
sa.Column('is_system_default', sa.Boolean(), nullable=True, comment='是否为系统默认模板'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
sa.PrimaryKeyConstraint('id')
)
op.create_index('idx_user_template', 'prompt_templates', ['user_id', 'template_key'], unique=True)
op.create_index(op.f('ix_prompt_templates_user_id'), 'prompt_templates', ['user_id'], unique=False)
op.create_table('relationship_types',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('name', sa.String(length=50), nullable=False, comment='关系名称'),
sa.Column('category', sa.String(length=20), nullable=False, comment='分类:family/social/hostile/professional'),
sa.Column('reverse_name', sa.String(length=50), nullable=True, comment='反向关系名称'),
sa.Column('intimacy_range', sa.String(length=20), nullable=True, comment='亲密度范围:high/medium/low'),
sa.Column('icon', sa.String(length=50), nullable=True, comment='图标标识'),
sa.Column('description', sa.Text(), nullable=True, comment='关系描述'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_relationship_types_id'), 'relationship_types', ['id'], unique=False)
op.create_table('settings',
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('user_id', sa.String(length=50), nullable=False, comment='用户ID'),
sa.Column('api_provider', sa.String(length=50), nullable=True, comment='API提供商'),
sa.Column('api_key', sa.String(length=500), nullable=True, comment='API密钥'),
sa.Column('api_base_url', sa.String(length=500), nullable=True, comment='自定义API地址'),
sa.Column('llm_model', sa.String(length=100), nullable=True, comment='模型名称'),
sa.Column('temperature', sa.Float(), nullable=True, comment='温度参数'),
sa.Column('max_tokens', sa.Integer(), nullable=True, comment='最大token数'),
sa.Column('preferences', sa.Text(), nullable=True, comment='其他偏好设置(JSON)'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
sa.PrimaryKeyConstraint('id')
)
op.create_index('idx_user_id', 'settings', ['user_id'], unique=False)
op.create_index(op.f('ix_settings_user_id'), 'settings', ['user_id'], unique=True)
op.create_table('user_passwords',
sa.Column('user_id', sa.String(length=100), nullable=False, comment='用户ID'),
sa.Column('username', sa.String(length=100), nullable=False, comment='用户名'),
sa.Column('password_hash', sa.String(length=64), nullable=False, comment='密码哈希(SHA256'),
sa.Column('has_custom_password', sa.Boolean(), nullable=True, comment='是否为自定义密码'),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
sa.PrimaryKeyConstraint('user_id')
)
op.create_index(op.f('ix_user_passwords_user_id'), 'user_passwords', ['user_id'], unique=False)
op.create_table('users',
sa.Column('user_id', sa.String(length=100), nullable=False, comment='用户ID,格式:linuxdo_{id} 或 local_{id}'),
sa.Column('username', sa.String(length=100), nullable=False, comment='用户名'),
sa.Column('display_name', sa.String(length=200), nullable=False, comment='显示名称'),
sa.Column('avatar_url', sa.String(length=500), nullable=True, comment='头像URL'),
sa.Column('trust_level', sa.Integer(), nullable=True, comment='信任等级(仅用于显示)'),
sa.Column('is_admin', sa.Boolean(), nullable=True, comment='是否为管理员'),
sa.Column('linuxdo_id', sa.String(length=100), nullable=False, comment='LinuxDO用户ID或本地用户ID'),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
sa.Column('last_login', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True, comment='最后登录时间'),
sa.PrimaryKeyConstraint('user_id')
)
op.create_index(op.f('ix_users_linuxdo_id'), 'users', ['linuxdo_id'], unique=True)
op.create_index(op.f('ix_users_user_id'), 'users', ['user_id'], unique=False)
op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=False)
op.create_table('careers',
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('project_id', sa.String(length=36), nullable=False),
sa.Column('name', sa.String(length=100), nullable=False, comment='职业名称'),
sa.Column('type', sa.String(length=20), nullable=False, comment='职业类型: main(主职业)/sub(副职业)'),
sa.Column('description', sa.Text(), nullable=True, comment='职业描述'),
sa.Column('category', sa.String(length=50), nullable=True, comment='职业分类(如:战斗系、生产系、辅助系)'),
sa.Column('stages', sa.Text(), nullable=False, comment="职业阶段列表(JSON): [{level:1, name:'', description:''}, ...]"),
sa.Column('max_stage', sa.Integer(), nullable=False, comment='最大阶段数'),
sa.Column('requirements', sa.Text(), nullable=True, comment='职业要求/限制'),
sa.Column('special_abilities', sa.Text(), nullable=True, comment='特殊能力描述'),
sa.Column('worldview_rules', sa.Text(), nullable=True, comment='世界观规则关联'),
sa.Column('attribute_bonuses', sa.Text(), nullable=True, comment="属性加成(JSON): {strength: '+10%', intelligence: '+5%'}"),
sa.Column('source', sa.String(length=20), nullable=True, comment='来源: ai/manual'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index('idx_project_id', 'careers', ['project_id'], unique=False)
op.create_index('idx_type', 'careers', ['type'], unique=False)
op.create_table('outlines',
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('project_id', sa.String(length=36), nullable=False),
sa.Column('title', sa.String(length=200), nullable=False, comment='大纲标题'),
sa.Column('content', sa.Text(), nullable=True, comment='大纲内容'),
sa.Column('structure', sa.Text(), nullable=True, comment='结构化大纲数据(JSON)'),
sa.Column('order_index', sa.Integer(), nullable=True, comment='排序序号'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_table('writing_styles',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('user_id', sa.String(length=255), nullable=True, comment='所属用户ID(NULL表示全局预设风格)'),
sa.Column('name', sa.String(length=100), nullable=False, comment='风格名称'),
sa.Column('style_type', sa.String(length=50), nullable=False, comment='风格类型:preset/custom'),
sa.Column('preset_id', sa.String(length=50), nullable=True, comment='预设风格IDnatural/classical/modern等'),
sa.Column('description', sa.Text(), nullable=True, comment='风格描述'),
sa.Column('prompt_content', sa.Text(), nullable=False, comment='风格提示词内容'),
sa.Column('order_index', sa.Integer(), nullable=True, comment='排序序号'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
sa.ForeignKeyConstraint(['user_id'], ['users.user_id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_table('chapters',
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('project_id', sa.String(length=36), nullable=False),
sa.Column('chapter_number', sa.Integer(), nullable=False, comment='章节序号'),
sa.Column('title', sa.String(length=200), nullable=False, comment='章节标题'),
sa.Column('content', sa.Text(), nullable=True, comment='章节内容'),
sa.Column('summary', sa.Text(), nullable=True, comment='章节摘要'),
sa.Column('word_count', sa.Integer(), nullable=True, comment='字数统计'),
sa.Column('status', sa.String(length=20), nullable=True, comment='章节状态'),
sa.Column('outline_id', sa.String(length=36), nullable=True, comment='关联的大纲ID'),
sa.Column('sub_index', sa.Integer(), nullable=True, comment='大纲下的子章节序号'),
sa.Column('expansion_plan', sa.Text(), nullable=True, comment='展开规划详情(JSON): 包含key_events, character_focus, emotional_tone等'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
sa.ForeignKeyConstraint(['outline_id'], ['outlines.id'], ondelete='SET NULL'),
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_table('characters',
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('project_id', sa.String(length=36), nullable=False),
sa.Column('name', sa.String(length=100), nullable=False, comment='角色/组织名称'),
sa.Column('age', sa.String(length=50), nullable=True, comment='年龄'),
sa.Column('gender', sa.String(length=50), nullable=True, comment='性别'),
sa.Column('is_organization', sa.Boolean(), nullable=True, comment='是否为组织'),
sa.Column('role_type', sa.String(length=50), nullable=True, comment='角色类型'),
sa.Column('personality', sa.Text(), nullable=True, comment='性格特点/组织特性'),
sa.Column('background', sa.Text(), nullable=True, comment='背景故事'),
sa.Column('appearance', sa.Text(), nullable=True, comment='外貌描述'),
sa.Column('relationships', sa.Text(), nullable=True, comment='人物关系(JSON)'),
sa.Column('organization_type', sa.String(length=100), nullable=True, comment='组织类型'),
sa.Column('organization_purpose', sa.String(length=500), nullable=True, comment='组织目的'),
sa.Column('organization_members', sa.Text(), nullable=True, comment='组织成员(JSON)'),
sa.Column('main_career_id', sa.String(length=36), nullable=True, comment='主职业ID'),
sa.Column('main_career_stage', sa.Integer(), nullable=True, comment='主职业当前阶段'),
sa.Column('sub_careers', sa.Text(), nullable=True, comment='副职业列表(JSON): [{"career_id": "xxx", "stage": 3}, ...]'),
sa.Column('avatar_url', sa.String(length=500), nullable=True, comment='头像URL'),
sa.Column('traits', sa.Text(), nullable=True, comment='特征标签(JSON)'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
sa.ForeignKeyConstraint(['main_career_id'], ['careers.id'], ondelete='SET NULL'),
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_table('project_default_styles',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('project_id', sa.String(length=36), nullable=False, comment='项目ID'),
sa.Column('style_id', sa.Integer(), nullable=False, comment='风格ID'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['style_id'], ['writing_styles.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('project_id', name='uix_project_default_style')
)
op.create_table('analysis_tasks',
sa.Column('id', sa.String(length=36), nullable=False, comment='任务ID'),
sa.Column('chapter_id', sa.String(length=36), nullable=False, comment='章节ID'),
sa.Column('user_id', sa.String(length=50), nullable=False, comment='用户ID'),
sa.Column('project_id', sa.String(length=36), nullable=False, comment='项目ID'),
sa.Column('status', sa.String(length=20), nullable=False, comment='任务状态: pending/running/completed/failed'),
sa.Column('progress', sa.Integer(), nullable=True, comment='进度 0-100'),
sa.Column('error_message', sa.Text(), nullable=True, comment='错误信息'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
sa.Column('started_at', sa.DateTime(), nullable=True, comment='开始执行时间'),
sa.Column('completed_at', sa.DateTime(), nullable=True, comment='完成时间'),
sa.ForeignKeyConstraint(['chapter_id'], ['chapters.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index('idx_chapter_id_created', 'analysis_tasks', ['chapter_id', 'created_at'], unique=False)
op.create_index('idx_status', 'analysis_tasks', ['status'], unique=False)
op.create_table('character_careers',
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('character_id', sa.String(length=36), nullable=False),
sa.Column('career_id', sa.String(length=36), nullable=False),
sa.Column('career_type', sa.String(length=20), nullable=False, comment='main(主职业)/sub(副职业)'),
sa.Column('current_stage', sa.Integer(), nullable=False, comment='当前阶段(对应职业中的数值)'),
sa.Column('stage_progress', sa.Integer(), nullable=True, comment='阶段内进度(0-100'),
sa.Column('started_at', sa.String(length=100), nullable=True, comment='开始修炼时间(小说时间线)'),
sa.Column('reached_current_stage_at', sa.String(length=100), nullable=True, comment='到达当前阶段时间'),
sa.Column('notes', sa.Text(), nullable=True, comment='备注(如:修炼心得、特殊事件)'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
sa.ForeignKeyConstraint(['career_id'], ['careers.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['character_id'], ['characters.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index('idx_career_type', 'character_careers', ['career_type'], unique=False)
op.create_index('idx_character_career', 'character_careers', ['character_id', 'career_id'], unique=True)
op.create_index('idx_character_id', 'character_careers', ['character_id'], unique=False)
op.create_table('character_relationships',
sa.Column('id', sa.String(length=36), nullable=False, comment='关系ID'),
sa.Column('project_id', sa.String(length=36), nullable=False, comment='项目ID'),
sa.Column('character_from_id', sa.String(length=36), nullable=False, comment='角色A的ID'),
sa.Column('character_to_id', sa.String(length=36), nullable=False, comment='角色B的ID'),
sa.Column('relationship_type_id', sa.Integer(), nullable=True, comment='关系类型ID'),
sa.Column('relationship_name', sa.String(length=100), nullable=True, comment='自定义关系名称'),
sa.Column('intimacy_level', sa.Integer(), nullable=True, comment='亲密度:-100到100'),
sa.Column('status', sa.String(length=20), nullable=True, comment='状态:active/broken/past/complicated'),
sa.Column('description', sa.Text(), nullable=True, comment='关系详细描述'),
sa.Column('started_at', sa.String(length=100), nullable=True, comment='关系开始时间(故事时间)'),
sa.Column('ended_at', sa.String(length=100), nullable=True, comment='关系结束时间(故事时间)'),
sa.Column('source', sa.String(length=20), nullable=True, comment='来源:ai/manual/imported'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
sa.ForeignKeyConstraint(['character_from_id'], ['characters.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['character_to_id'], ['characters.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['relationship_type_id'], ['relationship_types.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_character_relationships_character_from_id'), 'character_relationships', ['character_from_id'], unique=False)
op.create_index(op.f('ix_character_relationships_character_to_id'), 'character_relationships', ['character_to_id'], unique=False)
op.create_index(op.f('ix_character_relationships_project_id'), 'character_relationships', ['project_id'], unique=False)
op.create_index(op.f('ix_character_relationships_relationship_type_id'), 'character_relationships', ['relationship_type_id'], unique=False)
op.create_table('generation_history',
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('project_id', sa.String(length=36), nullable=False),
sa.Column('chapter_id', sa.String(length=36), nullable=True),
sa.Column('prompt', sa.Text(), nullable=True, comment='使用的提示词'),
sa.Column('generated_content', sa.Text(), nullable=True, comment='生成的内容'),
sa.Column('model', sa.String(length=50), nullable=True, comment='使用的模型'),
sa.Column('tokens_used', sa.Integer(), nullable=True, comment='消耗的token数'),
sa.Column('generation_time', sa.Float(), nullable=True, comment='生成耗时(秒)'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
sa.ForeignKeyConstraint(['chapter_id'], ['chapters.id'], ondelete='SET NULL'),
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_table('organizations',
sa.Column('id', sa.String(length=36), nullable=False, comment='组织ID'),
sa.Column('character_id', sa.String(length=36), nullable=False, comment='关联的角色ID'),
sa.Column('project_id', sa.String(length=36), nullable=False, comment='项目ID'),
sa.Column('parent_org_id', sa.String(length=36), nullable=True, comment='父组织ID'),
sa.Column('level', sa.Integer(), nullable=True, comment='组织层级'),
sa.Column('power_level', sa.Integer(), nullable=True, comment='势力等级:0-100'),
sa.Column('member_count', sa.Integer(), nullable=True, comment='成员数量'),
sa.Column('location', sa.Text(), nullable=True, comment='所在地'),
sa.Column('motto', sa.String(length=200), nullable=True, comment='宗旨/口号'),
sa.Column('color', sa.String(length=100), nullable=True, comment='代表颜色'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
sa.ForeignKeyConstraint(['character_id'], ['characters.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['parent_org_id'], ['organizations.id'], ondelete='SET NULL'),
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('character_id')
)
op.create_index(op.f('ix_organizations_project_id'), 'organizations', ['project_id'], unique=False)
op.create_table('plot_analysis',
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('project_id', sa.String(length=36), nullable=False),
sa.Column('chapter_id', sa.String(length=36), nullable=False),
sa.Column('plot_stage', sa.String(length=50), nullable=True, comment='剧情阶段: 开端/发展/高潮/结局/过渡'),
sa.Column('conflict_level', sa.Integer(), nullable=True, comment='冲突强度 1-10'),
sa.Column('conflict_types', sa.JSON(), nullable=True, comment="冲突类型列表: ['人与人', '人与己', '人与环境']"),
sa.Column('emotional_tone', sa.String(length=100), nullable=True, comment='主导情感: 紧张/温馨/悲伤/激昂/平静'),
sa.Column('emotional_intensity', sa.Float(), nullable=True, comment='情感强度 0.0-1.0'),
sa.Column('emotional_curve', sa.JSON(), nullable=True, comment='情感曲线: {start: 0.3, middle: 0.7, end: 0.5}'),
sa.Column('hooks', sa.JSON(), nullable=True, comment='钩子列表 - 吸引读者的元素: [\n {\n "type": "悬念|情感|冲突|认知",\n "content": "具体内容",\n "strength": 8,\n "position": "开头|中段|结尾"\n }\n ]'),
sa.Column('hooks_count', sa.Integer(), nullable=True, comment='钩子数量'),
sa.Column('hooks_avg_strength', sa.Float(), nullable=True, comment='钩子平均强度'),
sa.Column('foreshadows', sa.JSON(), nullable=True, comment='伏笔列表: [\n {\n "content": "伏笔内容",\n "type": "planted|resolved",\n "strength": 7,\n "subtlety": 8,\n "reference_chapter": 3\n }\n ]'),
sa.Column('foreshadows_planted', sa.Integer(), nullable=True, comment='本章埋下的伏笔数量'),
sa.Column('foreshadows_resolved', sa.Integer(), nullable=True, comment='本章回收的伏笔数量'),
sa.Column('plot_points', sa.JSON(), nullable=True, comment='情节点列表: [\n {\n "content": "情节点描述",\n "importance": 0.9,\n "type": "revelation|conflict|resolution|transition",\n "impact": "对故事的影响描述"\n }\n ]'),
sa.Column('plot_points_count', sa.Integer(), nullable=True, comment='情节点数量'),
sa.Column('character_states', sa.JSON(), nullable=True, comment='角色状态变化: [\n {\n "character_id": "xxx",\n "character_name": "张三",\n "state_before": "犹豫不决",\n "state_after": "坚定信念",\n "psychological_change": "内心描述",\n "key_event": "触发事件",\n "relationship_changes": {"李四": "关系变化"}\n }\n ]'),
sa.Column('scenes', sa.JSON(), nullable=True, comment="场景列表: [{location: '地点', atmosphere: '氛围', duration: '时长'}]"),
sa.Column('pacing', sa.String(length=50), nullable=True, comment='节奏: slow|moderate|fast|varied'),
sa.Column('overall_quality_score', sa.Float(), nullable=True, comment='整体质量评分 0.0-10.0'),
sa.Column('pacing_score', sa.Float(), nullable=True, comment='节奏评分 0.0-10.0'),
sa.Column('engagement_score', sa.Float(), nullable=True, comment='吸引力评分 0.0-10.0'),
sa.Column('coherence_score', sa.Float(), nullable=True, comment='连贯性评分 0.0-10.0'),
sa.Column('analysis_report', sa.Text(), nullable=True, comment='完整的文字分析报告'),
sa.Column('suggestions', sa.JSON(), nullable=True, comment="改进建议列表: ['建议1', '建议2']"),
sa.Column('word_count', sa.Integer(), nullable=True, comment='章节字数'),
sa.Column('dialogue_ratio', sa.Float(), nullable=True, comment='对话占比 0.0-1.0'),
sa.Column('description_ratio', sa.Float(), nullable=True, comment='描写占比 0.0-1.0'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='分析时间'),
sa.ForeignKeyConstraint(['chapter_id'], ['chapters.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_plot_analysis_chapter_id'), 'plot_analysis', ['chapter_id'], unique=True)
op.create_index(op.f('ix_plot_analysis_project_id'), 'plot_analysis', ['project_id'], unique=False)
op.create_table('regeneration_tasks',
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('chapter_id', sa.String(length=36), nullable=False),
sa.Column('analysis_id', sa.String(length=36), nullable=True, comment='关联的分析结果ID'),
sa.Column('user_id', sa.String(length=50), nullable=False),
sa.Column('project_id', sa.String(length=36), nullable=False),
sa.Column('modification_instructions', sa.Text(), nullable=False, comment='综合修改指令'),
sa.Column('original_suggestions', sa.JSON(), nullable=True, comment='来自分析的原始建议列表'),
sa.Column('selected_suggestion_indices', sa.JSON(), nullable=True, comment='用户选择的建议索引'),
sa.Column('custom_instructions', sa.Text(), nullable=True, comment='用户自定义修改意见'),
sa.Column('style_id', sa.Integer(), nullable=True, comment='写作风格ID'),
sa.Column('target_word_count', sa.Integer(), nullable=True, comment='目标字数'),
sa.Column('focus_areas', sa.JSON(), nullable=True, comment='重点优化方向'),
sa.Column('preserve_elements', sa.JSON(), nullable=True, comment='需要保留的元素配置'),
sa.Column('status', sa.String(length=20), nullable=True, comment='pending/running/completed/failed'),
sa.Column('progress', sa.Integer(), nullable=True, comment='进度 0-100'),
sa.Column('error_message', sa.Text(), nullable=True),
sa.Column('original_content', sa.Text(), nullable=True, comment='原始章节内容快照'),
sa.Column('original_word_count', sa.Integer(), nullable=True, comment='原始字数'),
sa.Column('regenerated_content', sa.Text(), nullable=True, comment='重新生成的内容'),
sa.Column('regenerated_word_count', sa.Integer(), nullable=True, comment='新内容字数'),
sa.Column('version_number', sa.Integer(), nullable=True, comment='版本号'),
sa.Column('version_note', sa.String(length=500), nullable=True, comment='版本说明'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True),
sa.Column('started_at', sa.DateTime(), nullable=True),
sa.Column('completed_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['chapter_id'], ['chapters.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_regeneration_tasks_chapter_id'), 'regeneration_tasks', ['chapter_id'], unique=False)
op.create_index(op.f('ix_regeneration_tasks_project_id'), 'regeneration_tasks', ['project_id'], unique=False)
op.create_index(op.f('ix_regeneration_tasks_user_id'), 'regeneration_tasks', ['user_id'], unique=False)
op.create_table('story_memories',
sa.Column('id', sa.String(length=100), nullable=False),
sa.Column('project_id', sa.String(length=36), nullable=False),
sa.Column('chapter_id', sa.String(length=36), nullable=True),
sa.Column('memory_type', sa.String(length=50), nullable=False, comment='\n 记忆类型:\n - plot_point: 情节点\n - character_event: 角色事件\n - world_detail: 世界观细节\n - hook: 钩子(悬念/冲突)\n - foreshadow: 伏笔\n - dialogue: 重要对话\n - scene: 场景描写\n '),
sa.Column('title', sa.String(length=200), nullable=True, comment='记忆标题/简述'),
sa.Column('content', sa.Text(), nullable=False, comment='记忆内容摘要(100-500字)'),
sa.Column('full_context', sa.Text(), nullable=True, comment='完整上下文(可选,用于详细记录)'),
sa.Column('related_characters', sa.JSON(), nullable=True, comment="涉及角色ID列表: ['char_id_1', 'char_id_2']"),
sa.Column('related_locations', sa.JSON(), nullable=True, comment="涉及地点列表: ['地点1', '地点2']"),
sa.Column('tags', sa.JSON(), nullable=True, comment="标签列表: ['悬念', '转折', '伏笔', '高潮']"),
sa.Column('importance_score', sa.Float(), nullable=True, comment='重要性评分 0.0-1.0'),
sa.Column('story_timeline', sa.Integer(), nullable=False, comment='故事时间线位置(章节序号)'),
sa.Column('chapter_position', sa.Integer(), nullable=True, comment='章节内位置(字符位置)'),
sa.Column('text_length', sa.Integer(), nullable=True, comment='文本长度(字符数)'),
sa.Column('is_foreshadow', sa.Integer(), nullable=True, comment='伏笔状态: 0=普通记忆, 1=已埋下伏笔, 2=伏笔已回收'),
sa.Column('foreshadow_resolved_at', sa.String(length=100), nullable=True, comment='伏笔回收的章节ID'),
sa.Column('foreshadow_strength', sa.Float(), nullable=True, comment='伏笔强度 0.0-1.0'),
sa.Column('vector_id', sa.String(length=100), nullable=True, comment='向量数据库中的唯一ID'),
sa.Column('embedding_model', sa.String(length=100), nullable=True, comment='使用的embedding模型'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
sa.ForeignKeyConstraint(['chapter_id'], ['chapters.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['foreshadow_resolved_at'], ['chapters.id'], ondelete='SET NULL'),
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('vector_id')
)
op.create_index(op.f('ix_story_memories_chapter_id'), 'story_memories', ['chapter_id'], unique=False)
op.create_index(op.f('ix_story_memories_memory_type'), 'story_memories', ['memory_type'], unique=False)
op.create_index(op.f('ix_story_memories_project_id'), 'story_memories', ['project_id'], unique=False)
op.create_index(op.f('ix_story_memories_story_timeline'), 'story_memories', ['story_timeline'], unique=False)
op.create_table('organization_members',
sa.Column('id', sa.String(length=36), nullable=False, comment='成员关系ID'),
sa.Column('organization_id', sa.String(length=36), nullable=False, comment='组织ID'),
sa.Column('character_id', sa.String(length=36), nullable=False, comment='角色ID'),
sa.Column('position', sa.String(length=100), nullable=False, comment='职位名称'),
sa.Column('rank', sa.Integer(), nullable=True, comment='职位等级'),
sa.Column('status', sa.String(length=20), nullable=True, comment='状态:active/retired/expelled/deceased'),
sa.Column('joined_at', sa.String(length=100), nullable=True, comment='加入时间(故事时间)'),
sa.Column('left_at', sa.String(length=100), nullable=True, comment='离开时间(故事时间)'),
sa.Column('loyalty', sa.Integer(), nullable=True, comment='忠诚度:0-100'),
sa.Column('contribution', sa.Integer(), nullable=True, comment='贡献度:0-100'),
sa.Column('source', sa.String(length=20), nullable=True, comment='来源:ai/manual'),
sa.Column('notes', sa.Text(), nullable=True, comment='备注'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
sa.ForeignKeyConstraint(['character_id'], ['characters.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_organization_members_character_id'), 'organization_members', ['character_id'], unique=False)
op.create_index(op.f('ix_organization_members_organization_id'), 'organization_members', ['organization_id'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
"""降级数据库结构"""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_organization_members_organization_id'), table_name='organization_members')
op.drop_index(op.f('ix_organization_members_character_id'), table_name='organization_members')
op.drop_table('organization_members')
op.drop_index(op.f('ix_story_memories_story_timeline'), table_name='story_memories')
op.drop_index(op.f('ix_story_memories_project_id'), table_name='story_memories')
op.drop_index(op.f('ix_story_memories_memory_type'), table_name='story_memories')
op.drop_index(op.f('ix_story_memories_chapter_id'), table_name='story_memories')
op.drop_table('story_memories')
op.drop_index(op.f('ix_regeneration_tasks_user_id'), table_name='regeneration_tasks')
op.drop_index(op.f('ix_regeneration_tasks_project_id'), table_name='regeneration_tasks')
op.drop_index(op.f('ix_regeneration_tasks_chapter_id'), table_name='regeneration_tasks')
op.drop_table('regeneration_tasks')
op.drop_index(op.f('ix_plot_analysis_project_id'), table_name='plot_analysis')
op.drop_index(op.f('ix_plot_analysis_chapter_id'), table_name='plot_analysis')
op.drop_table('plot_analysis')
op.drop_index(op.f('ix_organizations_project_id'), table_name='organizations')
op.drop_table('organizations')
op.drop_table('generation_history')
op.drop_index(op.f('ix_character_relationships_relationship_type_id'), table_name='character_relationships')
op.drop_index(op.f('ix_character_relationships_project_id'), table_name='character_relationships')
op.drop_index(op.f('ix_character_relationships_character_to_id'), table_name='character_relationships')
op.drop_index(op.f('ix_character_relationships_character_from_id'), table_name='character_relationships')
op.drop_table('character_relationships')
op.drop_index('idx_character_id', table_name='character_careers')
op.drop_index('idx_character_career', table_name='character_careers')
op.drop_index('idx_career_type', table_name='character_careers')
op.drop_table('character_careers')
op.drop_index('idx_status', table_name='analysis_tasks')
op.drop_index('idx_chapter_id_created', table_name='analysis_tasks')
op.drop_table('analysis_tasks')
op.drop_table('project_default_styles')
op.drop_table('characters')
op.drop_table('chapters')
op.drop_table('writing_styles')
op.drop_table('outlines')
op.drop_index('idx_type', table_name='careers')
op.drop_index('idx_project_id', table_name='careers')
op.drop_table('careers')
op.drop_index(op.f('ix_users_username'), table_name='users')
op.drop_index(op.f('ix_users_user_id'), table_name='users')
op.drop_index(op.f('ix_users_linuxdo_id'), table_name='users')
op.drop_table('users')
op.drop_index(op.f('ix_user_passwords_user_id'), table_name='user_passwords')
op.drop_table('user_passwords')
op.drop_index(op.f('ix_settings_user_id'), table_name='settings')
op.drop_index('idx_user_id', table_name='settings')
op.drop_table('settings')
op.drop_index(op.f('ix_relationship_types_id'), table_name='relationship_types')
op.drop_table('relationship_types')
op.drop_index(op.f('ix_prompt_templates_user_id'), table_name='prompt_templates')
op.drop_index('idx_user_template', table_name='prompt_templates')
op.drop_table('prompt_templates')
op.drop_index(op.f('ix_projects_user_id'), table_name='projects')
op.drop_table('projects')
op.drop_index(op.f('ix_mcp_plugins_user_id'), table_name='mcp_plugins')
op.drop_index('idx_user_plugin', table_name='mcp_plugins')
op.drop_index('idx_user_enabled', table_name='mcp_plugins')
op.drop_table('mcp_plugins')
op.drop_table('batch_generation_tasks')
# ### end Alembic commands ###
@@ -0,0 +1,181 @@
"""初始化预置数据
Revision ID: e411428f00c0
Revises: ee0a189f1532
Create Date: 2025-12-26 11:02:24.080526
"""
from typing import Sequence, Union
from datetime import datetime
from alembic import op
import sqlalchemy as sa
from sqlalchemy import table, column, String, Integer, Float, Text, Boolean, DateTime
# revision identifiers, used by Alembic.
revision: str = 'e411428f00c0'
down_revision: Union[str, None] = 'ee0a189f1532'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""插入预置数据"""
# ==================== 1. 插入关系类型数据 ====================
relationship_types_table = table(
'relationship_types',
column('name', String),
column('category', String),
column('reverse_name', String),
column('intimacy_range', String),
column('icon', String),
column('description', Text),
)
relationship_types_data = [
# 家庭关系
{"name": "父亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👨", "description": "父子/父女关系"},
{"name": "母亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👩", "description": "母子/母女关系"},
{"name": "兄弟", "category": "family", "reverse_name": "兄弟", "intimacy_range": "high", "icon": "👬", "description": "兄弟关系"},
{"name": "姐妹", "category": "family", "reverse_name": "姐妹", "intimacy_range": "high", "icon": "👭", "description": "姐妹关系"},
{"name": "子女", "category": "family", "reverse_name": "父母", "intimacy_range": "high", "icon": "👶", "description": "子女关系"},
{"name": "配偶", "category": "family", "reverse_name": "配偶", "intimacy_range": "high", "icon": "💑", "description": "夫妻关系"},
{"name": "恋人", "category": "family", "reverse_name": "恋人", "intimacy_range": "high", "icon": "💕", "description": "恋爱关系"},
# 社交关系
{"name": "师父", "category": "social", "reverse_name": "徒弟", "intimacy_range": "high", "icon": "🎓", "description": "师徒关系(师父视角)"},
{"name": "徒弟", "category": "social", "reverse_name": "师父", "intimacy_range": "high", "icon": "📚", "description": "师徒关系(徒弟视角)"},
{"name": "朋友", "category": "social", "reverse_name": "朋友", "intimacy_range": "medium", "icon": "🤝", "description": "朋友关系"},
{"name": "同学", "category": "social", "reverse_name": "同学", "intimacy_range": "medium", "icon": "🎒", "description": "同学关系"},
{"name": "邻居", "category": "social", "reverse_name": "邻居", "intimacy_range": "low", "icon": "🏘️", "description": "邻居关系"},
{"name": "知己", "category": "social", "reverse_name": "知己", "intimacy_range": "high", "icon": "💙", "description": "知心好友"},
# 职业关系
{"name": "上司", "category": "professional", "reverse_name": "下属", "intimacy_range": "low", "icon": "👔", "description": "上下级关系(上司视角)"},
{"name": "下属", "category": "professional", "reverse_name": "上司", "intimacy_range": "low", "icon": "💼", "description": "上下级关系(下属视角)"},
{"name": "同事", "category": "professional", "reverse_name": "同事", "intimacy_range": "medium", "icon": "🤵", "description": "同事关系"},
{"name": "合作伙伴", "category": "professional", "reverse_name": "合作伙伴", "intimacy_range": "medium", "icon": "🤜🤛", "description": "合作关系"},
# 敌对关系
{"name": "敌人", "category": "hostile", "reverse_name": "敌人", "intimacy_range": "low", "icon": "⚔️", "description": "敌对关系"},
{"name": "仇人", "category": "hostile", "reverse_name": "仇人", "intimacy_range": "low", "icon": "💢", "description": "仇恨关系"},
{"name": "竞争对手", "category": "hostile", "reverse_name": "竞争对手", "intimacy_range": "low", "icon": "🎯", "description": "竞争关系"},
{"name": "宿敌", "category": "hostile", "reverse_name": "宿敌", "intimacy_range": "low", "icon": "", "description": "宿命之敌"},
]
op.bulk_insert(relationship_types_table, relationship_types_data)
print(f"✅ 已插入 {len(relationship_types_data)} 条关系类型数据")
# ==================== 2. 插入全局写作风格预设 ====================
# 注意:这里需要从 WritingStyleManager 获取预设配置
# 为了避免导入应用代码,我们直接硬编码预设风格
writing_styles_table = table(
'writing_styles',
column('user_id', String),
column('name', String),
column('style_type', String),
column('preset_id', String),
column('description', Text),
column('prompt_content', Text),
column('order_index', Integer),
)
writing_styles_data = [
{
"user_id": None, # NULL 表示全局预设
"name": "自然流畅",
"style_type": "preset",
"preset_id": "natural",
"description": "自然流畅的叙事风格,适合现代都市、现实题材",
"prompt_content": """写作风格要求:
1. 语言简洁明快,贴近现代口语
2. 多用短句,节奏流畅
3. 注重情感细节的自然流露
4. 避免过度修饰和复杂句式""",
"order_index": 1
},
{
"user_id": None,
"name": "古典优雅",
"style_type": "preset",
"preset_id": "classical",
"description": "古典文雅的写作风格,适合古装、仙侠题材",
"prompt_content": """写作风格要求:
1. 使用文言、半文言或典雅的白话
2. 适当运用古典诗词意象
3. 注重意境营造和韵味
4. 对话和描写保持古典美感""",
"order_index": 2
},
{
"user_id": None,
"name": "现代简约",
"style_type": "preset",
"preset_id": "modern",
"description": "现代简约风格,适合轻小说、网文快节奏叙事",
"prompt_content": """写作风格要求:
1. 语言直白简练,信息密度高
2. 多用对话推进情节
3. 避免冗长描写,突出关键动作
4. 节奏明快,适合快速阅读""",
"order_index": 3
},
{
"user_id": None,
"name": "文艺细腻",
"style_type": "preset",
"preset_id": "literary",
"description": "文艺细腻风格,注重心理描写和氛围营造",
"prompt_content": """写作风格要求:
1. 注重心理活动和情感细节
2. 善用环境描写烘托氛围
3. 语言优美,富有文学性
4. 适当使用比喻、象征等修辞手法""",
"order_index": 4
},
{
"user_id": None,
"name": "紧张悬疑",
"style_type": "preset",
"preset_id": "suspense",
"description": "紧张悬疑风格,适合推理、惊悚题材",
"prompt_content": """写作风格要求:
1. 营造紧张压迫的氛围
2. 多用短句加快节奏
3. 善于设置悬念和伏笔
4. 注重细节描写,为推理埋下线索""",
"order_index": 5
},
{
"user_id": None,
"name": "幽默诙谐",
"style_type": "preset",
"preset_id": "humorous",
"description": "幽默诙谐风格,适合轻松搞笑题材",
"prompt_content": """写作风格要求:
1. 语言活泼风趣,善用俏皮话
2. 注重对话的喜剧效果
3. 适当夸张和反转制造笑点
4. 保持轻松愉快的基调""",
"order_index": 6
},
]
op.bulk_insert(writing_styles_table, writing_styles_data)
print(f"✅ 已插入 {len(writing_styles_data)} 条全局写作风格预设")
def downgrade() -> None:
"""删除预置数据"""
# 删除写作风格预设(只删除全局预设)
op.execute("DELETE FROM writing_styles WHERE user_id IS NULL")
print("✅ 已删除全局写作风格预设")
# 删除关系类型
op.execute("DELETE FROM relationship_types")
print("✅ 已删除关系类型数据")
@@ -0,0 +1,30 @@
"""添加system_prompt字段到settings表
Revision ID: a7e4408e1d5b
Revises: e411428f00c0
Create Date: 2025-12-27 15:41:22.310160
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'a7e4408e1d5b'
down_revision: Union[str, None] = 'e411428f00c0'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('settings', sa.Column('system_prompt', sa.Text(), nullable=True, comment='系统级别提示词,每次AI调用都会使用'))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('settings', 'system_prompt')
# ### end Alembic commands ###
+2
View File
@@ -0,0 +1,2 @@
# 此文件确保 versions 目录被 Git 追踪
# 迁移版本文件将存放在此目录
+102
View File
@@ -0,0 +1,102 @@
"""Alembic 环境配置文件 - SQLite"""
import asyncio
import os
import sys
from logging.config import fileConfig
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import async_engine_from_config
from alembic import context
# 添加项目根目录到 Python 路径
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if project_root not in sys.path:
sys.path.insert(0, project_root)
# 导入应用配置
from app.config import settings
# 导入 Base 和所有模型
from app.database import Base
from app.models import (
Project, Outline, Character, Chapter, GenerationHistory,
Settings, WritingStyle, ProjectDefaultStyle,
RelationshipType, CharacterRelationship, Organization, OrganizationMember,
StoryMemory, PlotAnalysis, AnalysisTask, BatchGenerationTask,
RegenerationTask, Career, CharacterCareer, User, MCPPlugin, PromptTemplate
)
# Alembic Config 对象
config = context.config
# 设置数据库连接字符串(从环境变量读取)
config.set_main_option("sqlalchemy.url", settings.database_url)
# 配置日志
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# 设置 target_metadata 为应用的 Base.metadata
target_metadata = Base.metadata
def run_migrations_offline() -> None:
"""'离线'模式下运行迁移"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
compare_type=True,
compare_server_default=True,
render_as_batch=True, # SQLite 必须启用批处理模式
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection: Connection) -> None:
"""执行迁移的核心函数 - SQLite 专用"""
context.configure(
connection=connection,
target_metadata=target_metadata,
compare_type=True,
compare_server_default=True,
render_as_batch=True, # SQLite 必须启用批处理模式
)
with context.begin_transaction():
context.run_migrations()
async def run_async_migrations() -> None:
"""'在线'模式下运行异步迁移"""
configuration = config.get_section(config.config_ini_section, {})
configuration["sqlalchemy.url"] = settings.database_url
connectable = async_engine_from_config(
configuration,
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
await connectable.dispose()
def run_migrations_online() -> None:
"""'在线'模式下运行迁移"""
asyncio.run(run_async_migrations())
# 根据上下文选择运行模式
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
+26
View File
@@ -0,0 +1,26 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}
@@ -0,0 +1,616 @@
"""初始化SQLite数据库
Revision ID: fbeb1038c728
Revises:
Create Date: 2025-12-26 13:22:53.151546
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'fbeb1038c728'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('batch_generation_tasks',
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('project_id', sa.String(length=36), nullable=False, comment='项目ID'),
sa.Column('user_id', sa.String(length=100), nullable=False, comment='用户ID'),
sa.Column('start_chapter_number', sa.Integer(), nullable=False, comment='起始章节序号'),
sa.Column('chapter_count', sa.Integer(), nullable=False, comment='生成章节数量'),
sa.Column('chapter_ids', sa.JSON(), nullable=False, comment='待生成的章节ID列表'),
sa.Column('style_id', sa.Integer(), nullable=True, comment='使用的写作风格ID'),
sa.Column('target_word_count', sa.Integer(), nullable=True, comment='目标字数'),
sa.Column('enable_analysis', sa.Boolean(), nullable=True, comment='是否启用同步分析'),
sa.Column('status', sa.String(length=20), nullable=True, comment='任务状态: pending/running/completed/failed/cancelled'),
sa.Column('total_chapters', sa.Integer(), nullable=True, comment='总章节数'),
sa.Column('completed_chapters', sa.Integer(), nullable=True, comment='已完成章节数'),
sa.Column('failed_chapters', sa.JSON(), nullable=True, comment='失败的章节信息列表'),
sa.Column('current_chapter_id', sa.String(length=36), nullable=True, comment='当前正在生成的章节ID'),
sa.Column('current_chapter_number', sa.Integer(), nullable=True, comment='当前正在生成的章节序号'),
sa.Column('current_retry_count', sa.Integer(), nullable=True, comment='当前章节重试次数'),
sa.Column('max_retries', sa.Integer(), nullable=True, comment='最大重试次数'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
sa.Column('started_at', sa.DateTime(), nullable=True, comment='开始时间'),
sa.Column('completed_at', sa.DateTime(), nullable=True, comment='完成时间'),
sa.Column('error_message', sa.String(length=500), nullable=True, comment='错误信息'),
sa.PrimaryKeyConstraint('id')
)
op.create_table('mcp_plugins',
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('user_id', sa.String(length=50), nullable=False, comment='用户ID'),
sa.Column('plugin_name', sa.String(length=100), nullable=False, comment='插件名称(唯一标识)'),
sa.Column('display_name', sa.String(length=200), nullable=False, comment='显示名称'),
sa.Column('description', sa.Text(), nullable=True, comment='插件描述'),
sa.Column('plugin_type', sa.String(length=50), nullable=True, comment='插件类型:http/stdio'),
sa.Column('server_url', sa.String(length=500), nullable=True, comment='服务器URLHTTP类型)'),
sa.Column('command', sa.String(length=500), nullable=True, comment='启动命令(stdio类型)'),
sa.Column('args', sa.JSON(), nullable=True, comment='命令参数(stdio类型)'),
sa.Column('env', sa.JSON(), nullable=True, comment='环境变量'),
sa.Column('headers', sa.JSON(), nullable=True, comment='HTTP请求头'),
sa.Column('config', sa.JSON(), nullable=True, comment='插件特定配置(JSON'),
sa.Column('tools', sa.JSON(), nullable=True, comment='提供的工具列表'),
sa.Column('enabled', sa.Boolean(), nullable=True, comment='是否启用'),
sa.Column('status', sa.String(length=50), nullable=True, comment='状态:active/inactive/error'),
sa.Column('last_error', sa.Text(), nullable=True, comment='最后错误信息'),
sa.Column('last_test_at', sa.DateTime(), nullable=True, comment='最后测试时间'),
sa.Column('category', sa.String(length=100), nullable=True, comment='分类'),
sa.Column('sort_order', sa.Integer(), nullable=True, comment='排序顺序'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
sa.PrimaryKeyConstraint('id')
)
with op.batch_alter_table('mcp_plugins', schema=None) as batch_op:
batch_op.create_index('idx_user_enabled', ['user_id', 'enabled'], unique=False)
batch_op.create_index('idx_user_plugin', ['user_id', 'plugin_name'], unique=True)
batch_op.create_index(batch_op.f('ix_mcp_plugins_user_id'), ['user_id'], unique=False)
op.create_table('projects',
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('user_id', sa.String(length=100), nullable=False, comment='用户ID'),
sa.Column('title', sa.String(length=200), nullable=False, comment='项目标题'),
sa.Column('description', sa.Text(), nullable=True, comment='项目简介'),
sa.Column('theme', sa.Text(), nullable=True, comment='主题'),
sa.Column('genre', sa.String(length=50), nullable=True, comment='小说类型'),
sa.Column('target_words', sa.Integer(), nullable=True, comment='目标字数'),
sa.Column('current_words', sa.Integer(), nullable=True, comment='当前字数'),
sa.Column('status', sa.String(length=20), nullable=True, comment='创作状态'),
sa.Column('wizard_status', sa.String(length=20), nullable=True, comment='向导完成状态: incomplete/completed'),
sa.Column('wizard_step', sa.Integer(), nullable=True, comment='向导当前步骤: 0-4'),
sa.Column('outline_mode', sa.String(length=20), nullable=False, comment='大纲章节模式: one-to-one(传统模式) 或 one-to-many(细化模式)'),
sa.Column('world_time_period', sa.Text(), nullable=True, comment='时间背景'),
sa.Column('world_location', sa.Text(), nullable=True, comment='地理位置'),
sa.Column('world_atmosphere', sa.Text(), nullable=True, comment='氛围基调'),
sa.Column('world_rules', sa.Text(), nullable=True, comment='世界规则'),
sa.Column('chapter_count', sa.Integer(), nullable=True, comment='章节数量'),
sa.Column('narrative_perspective', sa.String(length=50), nullable=True, comment='叙事视角:first_person/third_person/omniscient'),
sa.Column('character_count', sa.Integer(), nullable=True, comment='角色数量'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
sa.CheckConstraint("outline_mode IN ('one-to-one', 'one-to-many')", name='check_outline_mode'),
sa.PrimaryKeyConstraint('id')
)
with op.batch_alter_table('projects', schema=None) as batch_op:
batch_op.create_index(batch_op.f('ix_projects_user_id'), ['user_id'], unique=False)
op.create_table('prompt_templates',
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('user_id', sa.String(length=50), nullable=False, comment='用户ID'),
sa.Column('template_key', sa.String(length=100), nullable=False, comment='模板键名'),
sa.Column('template_name', sa.String(length=200), nullable=False, comment='模板显示名称'),
sa.Column('template_content', sa.Text(), nullable=False, comment='模板内容'),
sa.Column('description', sa.Text(), nullable=True, comment='模板描述'),
sa.Column('category', sa.String(length=50), nullable=True, comment='模板分类'),
sa.Column('parameters', sa.Text(), nullable=True, comment='模板参数定义(JSON)'),
sa.Column('is_active', sa.Boolean(), nullable=True, comment='是否启用'),
sa.Column('is_system_default', sa.Boolean(), nullable=True, comment='是否为系统默认模板'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
sa.PrimaryKeyConstraint('id')
)
with op.batch_alter_table('prompt_templates', schema=None) as batch_op:
batch_op.create_index('idx_user_template', ['user_id', 'template_key'], unique=True)
batch_op.create_index(batch_op.f('ix_prompt_templates_user_id'), ['user_id'], unique=False)
op.create_table('relationship_types',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('name', sa.String(length=50), nullable=False, comment='关系名称'),
sa.Column('category', sa.String(length=20), nullable=False, comment='分类:family/social/hostile/professional'),
sa.Column('reverse_name', sa.String(length=50), nullable=True, comment='反向关系名称'),
sa.Column('intimacy_range', sa.String(length=20), nullable=True, comment='亲密度范围:high/medium/low'),
sa.Column('icon', sa.String(length=50), nullable=True, comment='图标标识'),
sa.Column('description', sa.Text(), nullable=True, comment='关系描述'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
sa.PrimaryKeyConstraint('id')
)
with op.batch_alter_table('relationship_types', schema=None) as batch_op:
batch_op.create_index(batch_op.f('ix_relationship_types_id'), ['id'], unique=False)
op.create_table('settings',
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('user_id', sa.String(length=50), nullable=False, comment='用户ID'),
sa.Column('api_provider', sa.String(length=50), nullable=True, comment='API提供商'),
sa.Column('api_key', sa.String(length=500), nullable=True, comment='API密钥'),
sa.Column('api_base_url', sa.String(length=500), nullable=True, comment='自定义API地址'),
sa.Column('llm_model', sa.String(length=100), nullable=True, comment='模型名称'),
sa.Column('temperature', sa.Float(), nullable=True, comment='温度参数'),
sa.Column('max_tokens', sa.Integer(), nullable=True, comment='最大token数'),
sa.Column('preferences', sa.Text(), nullable=True, comment='其他偏好设置(JSON)'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
sa.PrimaryKeyConstraint('id')
)
with op.batch_alter_table('settings', schema=None) as batch_op:
batch_op.create_index('idx_user_id', ['user_id'], unique=False)
batch_op.create_index(batch_op.f('ix_settings_user_id'), ['user_id'], unique=True)
op.create_table('user_passwords',
sa.Column('user_id', sa.String(length=100), nullable=False, comment='用户ID'),
sa.Column('username', sa.String(length=100), nullable=False, comment='用户名'),
sa.Column('password_hash', sa.String(length=64), nullable=False, comment='密码哈希(SHA256'),
sa.Column('has_custom_password', sa.Boolean(), nullable=True, comment='是否为自定义密码'),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
sa.PrimaryKeyConstraint('user_id')
)
with op.batch_alter_table('user_passwords', schema=None) as batch_op:
batch_op.create_index(batch_op.f('ix_user_passwords_user_id'), ['user_id'], unique=False)
op.create_table('users',
sa.Column('user_id', sa.String(length=100), nullable=False, comment='用户ID,格式:linuxdo_{id} 或 local_{id}'),
sa.Column('username', sa.String(length=100), nullable=False, comment='用户名'),
sa.Column('display_name', sa.String(length=200), nullable=False, comment='显示名称'),
sa.Column('avatar_url', sa.String(length=500), nullable=True, comment='头像URL'),
sa.Column('trust_level', sa.Integer(), nullable=True, comment='信任等级(仅用于显示)'),
sa.Column('is_admin', sa.Boolean(), nullable=True, comment='是否为管理员'),
sa.Column('linuxdo_id', sa.String(length=100), nullable=False, comment='LinuxDO用户ID或本地用户ID'),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
sa.Column('last_login', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='最后登录时间'),
sa.PrimaryKeyConstraint('user_id')
)
with op.batch_alter_table('users', schema=None) as batch_op:
batch_op.create_index(batch_op.f('ix_users_linuxdo_id'), ['linuxdo_id'], unique=True)
batch_op.create_index(batch_op.f('ix_users_user_id'), ['user_id'], unique=False)
batch_op.create_index(batch_op.f('ix_users_username'), ['username'], unique=False)
op.create_table('careers',
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('project_id', sa.String(length=36), nullable=False),
sa.Column('name', sa.String(length=100), nullable=False, comment='职业名称'),
sa.Column('type', sa.String(length=20), nullable=False, comment='职业类型: main(主职业)/sub(副职业)'),
sa.Column('description', sa.Text(), nullable=True, comment='职业描述'),
sa.Column('category', sa.String(length=50), nullable=True, comment='职业分类(如:战斗系、生产系、辅助系)'),
sa.Column('stages', sa.Text(), nullable=False, comment="职业阶段列表(JSON): [{level:1, name:'', description:''}, ...]"),
sa.Column('max_stage', sa.Integer(), nullable=False, comment='最大阶段数'),
sa.Column('requirements', sa.Text(), nullable=True, comment='职业要求/限制'),
sa.Column('special_abilities', sa.Text(), nullable=True, comment='特殊能力描述'),
sa.Column('worldview_rules', sa.Text(), nullable=True, comment='世界观规则关联'),
sa.Column('attribute_bonuses', sa.Text(), nullable=True, comment="属性加成(JSON): {strength: '+10%', intelligence: '+5%'}"),
sa.Column('source', sa.String(length=20), nullable=True, comment='来源: ai/manual'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
with op.batch_alter_table('careers', schema=None) as batch_op:
batch_op.create_index('idx_project_id', ['project_id'], unique=False)
batch_op.create_index('idx_type', ['type'], unique=False)
op.create_table('outlines',
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('project_id', sa.String(length=36), nullable=False),
sa.Column('title', sa.String(length=200), nullable=False, comment='大纲标题'),
sa.Column('content', sa.Text(), nullable=True, comment='大纲内容'),
sa.Column('structure', sa.Text(), nullable=True, comment='结构化大纲数据(JSON)'),
sa.Column('order_index', sa.Integer(), nullable=True, comment='排序序号'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_table('writing_styles',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('user_id', sa.String(length=255), nullable=True, comment='所属用户ID(NULL表示全局预设风格)'),
sa.Column('name', sa.String(length=100), nullable=False, comment='风格名称'),
sa.Column('style_type', sa.String(length=50), nullable=False, comment='风格类型:preset/custom'),
sa.Column('preset_id', sa.String(length=50), nullable=True, comment='预设风格IDnatural/classical/modern等'),
sa.Column('description', sa.Text(), nullable=True, comment='风格描述'),
sa.Column('prompt_content', sa.Text(), nullable=False, comment='风格提示词内容'),
sa.Column('order_index', sa.Integer(), nullable=True, comment='排序序号'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
sa.ForeignKeyConstraint(['user_id'], ['users.user_id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_table('chapters',
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('project_id', sa.String(length=36), nullable=False),
sa.Column('chapter_number', sa.Integer(), nullable=False, comment='章节序号'),
sa.Column('title', sa.String(length=200), nullable=False, comment='章节标题'),
sa.Column('content', sa.Text(), nullable=True, comment='章节内容'),
sa.Column('summary', sa.Text(), nullable=True, comment='章节摘要'),
sa.Column('word_count', sa.Integer(), nullable=True, comment='字数统计'),
sa.Column('status', sa.String(length=20), nullable=True, comment='章节状态'),
sa.Column('outline_id', sa.String(length=36), nullable=True, comment='关联的大纲ID'),
sa.Column('sub_index', sa.Integer(), nullable=True, comment='大纲下的子章节序号'),
sa.Column('expansion_plan', sa.Text(), nullable=True, comment='展开规划详情(JSON): 包含key_events, character_focus, emotional_tone等'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
sa.ForeignKeyConstraint(['outline_id'], ['outlines.id'], ondelete='SET NULL'),
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_table('characters',
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('project_id', sa.String(length=36), nullable=False),
sa.Column('name', sa.String(length=100), nullable=False, comment='角色/组织名称'),
sa.Column('age', sa.String(length=50), nullable=True, comment='年龄'),
sa.Column('gender', sa.String(length=50), nullable=True, comment='性别'),
sa.Column('is_organization', sa.Boolean(), nullable=True, comment='是否为组织'),
sa.Column('role_type', sa.String(length=50), nullable=True, comment='角色类型'),
sa.Column('personality', sa.Text(), nullable=True, comment='性格特点/组织特性'),
sa.Column('background', sa.Text(), nullable=True, comment='背景故事'),
sa.Column('appearance', sa.Text(), nullable=True, comment='外貌描述'),
sa.Column('relationships', sa.Text(), nullable=True, comment='人物关系(JSON)'),
sa.Column('organization_type', sa.String(length=100), nullable=True, comment='组织类型'),
sa.Column('organization_purpose', sa.String(length=500), nullable=True, comment='组织目的'),
sa.Column('organization_members', sa.Text(), nullable=True, comment='组织成员(JSON)'),
sa.Column('main_career_id', sa.String(length=36), nullable=True, comment='主职业ID'),
sa.Column('main_career_stage', sa.Integer(), nullable=True, comment='主职业当前阶段'),
sa.Column('sub_careers', sa.Text(), nullable=True, comment='副职业列表(JSON): [{"career_id": "xxx", "stage": 3}, ...]'),
sa.Column('avatar_url', sa.String(length=500), nullable=True, comment='头像URL'),
sa.Column('traits', sa.Text(), nullable=True, comment='特征标签(JSON)'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
sa.ForeignKeyConstraint(['main_career_id'], ['careers.id'], ondelete='SET NULL'),
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_table('project_default_styles',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('project_id', sa.String(length=36), nullable=False, comment='项目ID'),
sa.Column('style_id', sa.Integer(), nullable=False, comment='风格ID'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['style_id'], ['writing_styles.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('project_id', name='uix_project_default_style')
)
op.create_table('analysis_tasks',
sa.Column('id', sa.String(length=36), nullable=False, comment='任务ID'),
sa.Column('chapter_id', sa.String(length=36), nullable=False, comment='章节ID'),
sa.Column('user_id', sa.String(length=50), nullable=False, comment='用户ID'),
sa.Column('project_id', sa.String(length=36), nullable=False, comment='项目ID'),
sa.Column('status', sa.String(length=20), nullable=False, comment='任务状态: pending/running/completed/failed'),
sa.Column('progress', sa.Integer(), nullable=True, comment='进度 0-100'),
sa.Column('error_message', sa.Text(), nullable=True, comment='错误信息'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
sa.Column('started_at', sa.DateTime(), nullable=True, comment='开始执行时间'),
sa.Column('completed_at', sa.DateTime(), nullable=True, comment='完成时间'),
sa.ForeignKeyConstraint(['chapter_id'], ['chapters.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
with op.batch_alter_table('analysis_tasks', schema=None) as batch_op:
batch_op.create_index('idx_chapter_id_created', ['chapter_id', 'created_at'], unique=False)
batch_op.create_index('idx_status', ['status'], unique=False)
op.create_table('character_careers',
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('character_id', sa.String(length=36), nullable=False),
sa.Column('career_id', sa.String(length=36), nullable=False),
sa.Column('career_type', sa.String(length=20), nullable=False, comment='main(主职业)/sub(副职业)'),
sa.Column('current_stage', sa.Integer(), nullable=False, comment='当前阶段(对应职业中的数值)'),
sa.Column('stage_progress', sa.Integer(), nullable=True, comment='阶段内进度(0-100'),
sa.Column('started_at', sa.String(length=100), nullable=True, comment='开始修炼时间(小说时间线)'),
sa.Column('reached_current_stage_at', sa.String(length=100), nullable=True, comment='到达当前阶段时间'),
sa.Column('notes', sa.Text(), nullable=True, comment='备注(如:修炼心得、特殊事件)'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
sa.ForeignKeyConstraint(['career_id'], ['careers.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['character_id'], ['characters.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
with op.batch_alter_table('character_careers', schema=None) as batch_op:
batch_op.create_index('idx_career_type', ['career_type'], unique=False)
batch_op.create_index('idx_character_career', ['character_id', 'career_id'], unique=True)
batch_op.create_index('idx_character_id', ['character_id'], unique=False)
op.create_table('character_relationships',
sa.Column('id', sa.String(length=36), nullable=False, comment='关系ID'),
sa.Column('project_id', sa.String(length=36), nullable=False, comment='项目ID'),
sa.Column('character_from_id', sa.String(length=36), nullable=False, comment='角色A的ID'),
sa.Column('character_to_id', sa.String(length=36), nullable=False, comment='角色B的ID'),
sa.Column('relationship_type_id', sa.Integer(), nullable=True, comment='关系类型ID'),
sa.Column('relationship_name', sa.String(length=100), nullable=True, comment='自定义关系名称'),
sa.Column('intimacy_level', sa.Integer(), nullable=True, comment='亲密度:-100到100'),
sa.Column('status', sa.String(length=20), nullable=True, comment='状态:active/broken/past/complicated'),
sa.Column('description', sa.Text(), nullable=True, comment='关系详细描述'),
sa.Column('started_at', sa.String(length=100), nullable=True, comment='关系开始时间(故事时间)'),
sa.Column('ended_at', sa.String(length=100), nullable=True, comment='关系结束时间(故事时间)'),
sa.Column('source', sa.String(length=20), nullable=True, comment='来源:ai/manual/imported'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
sa.ForeignKeyConstraint(['character_from_id'], ['characters.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['character_to_id'], ['characters.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['relationship_type_id'], ['relationship_types.id'], ),
sa.PrimaryKeyConstraint('id')
)
with op.batch_alter_table('character_relationships', schema=None) as batch_op:
batch_op.create_index(batch_op.f('ix_character_relationships_character_from_id'), ['character_from_id'], unique=False)
batch_op.create_index(batch_op.f('ix_character_relationships_character_to_id'), ['character_to_id'], unique=False)
batch_op.create_index(batch_op.f('ix_character_relationships_project_id'), ['project_id'], unique=False)
batch_op.create_index(batch_op.f('ix_character_relationships_relationship_type_id'), ['relationship_type_id'], unique=False)
op.create_table('generation_history',
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('project_id', sa.String(length=36), nullable=False),
sa.Column('chapter_id', sa.String(length=36), nullable=True),
sa.Column('prompt', sa.Text(), nullable=True, comment='使用的提示词'),
sa.Column('generated_content', sa.Text(), nullable=True, comment='生成的内容'),
sa.Column('model', sa.String(length=50), nullable=True, comment='使用的模型'),
sa.Column('tokens_used', sa.Integer(), nullable=True, comment='消耗的token数'),
sa.Column('generation_time', sa.Float(), nullable=True, comment='生成耗时(秒)'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
sa.ForeignKeyConstraint(['chapter_id'], ['chapters.id'], ondelete='SET NULL'),
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_table('organizations',
sa.Column('id', sa.String(length=36), nullable=False, comment='组织ID'),
sa.Column('character_id', sa.String(length=36), nullable=False, comment='关联的角色ID'),
sa.Column('project_id', sa.String(length=36), nullable=False, comment='项目ID'),
sa.Column('parent_org_id', sa.String(length=36), nullable=True, comment='父组织ID'),
sa.Column('level', sa.Integer(), nullable=True, comment='组织层级'),
sa.Column('power_level', sa.Integer(), nullable=True, comment='势力等级:0-100'),
sa.Column('member_count', sa.Integer(), nullable=True, comment='成员数量'),
sa.Column('location', sa.Text(), nullable=True, comment='所在地'),
sa.Column('motto', sa.String(length=200), nullable=True, comment='宗旨/口号'),
sa.Column('color', sa.String(length=100), nullable=True, comment='代表颜色'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
sa.ForeignKeyConstraint(['character_id'], ['characters.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['parent_org_id'], ['organizations.id'], ondelete='SET NULL'),
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('character_id')
)
with op.batch_alter_table('organizations', schema=None) as batch_op:
batch_op.create_index(batch_op.f('ix_organizations_project_id'), ['project_id'], unique=False)
op.create_table('plot_analysis',
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('project_id', sa.String(length=36), nullable=False),
sa.Column('chapter_id', sa.String(length=36), nullable=False),
sa.Column('plot_stage', sa.String(length=50), nullable=True, comment='剧情阶段: 开端/发展/高潮/结局/过渡'),
sa.Column('conflict_level', sa.Integer(), nullable=True, comment='冲突强度 1-10'),
sa.Column('conflict_types', sa.JSON(), nullable=True, comment="冲突类型列表: ['人与人', '人与己', '人与环境']"),
sa.Column('emotional_tone', sa.String(length=100), nullable=True, comment='主导情感: 紧张/温馨/悲伤/激昂/平静'),
sa.Column('emotional_intensity', sa.Float(), nullable=True, comment='情感强度 0.0-1.0'),
sa.Column('emotional_curve', sa.JSON(), nullable=True, comment='情感曲线: {start: 0.3, middle: 0.7, end: 0.5}'),
sa.Column('hooks', sa.JSON(), nullable=True, comment='钩子列表 - 吸引读者的元素: [\n {\n "type": "悬念|情感|冲突|认知",\n "content": "具体内容",\n "strength": 8,\n "position": "开头|中段|结尾"\n }\n ]'),
sa.Column('hooks_count', sa.Integer(), nullable=True, comment='钩子数量'),
sa.Column('hooks_avg_strength', sa.Float(), nullable=True, comment='钩子平均强度'),
sa.Column('foreshadows', sa.JSON(), nullable=True, comment='伏笔列表: [\n {\n "content": "伏笔内容",\n "type": "planted|resolved",\n "strength": 7,\n "subtlety": 8,\n "reference_chapter": 3\n }\n ]'),
sa.Column('foreshadows_planted', sa.Integer(), nullable=True, comment='本章埋下的伏笔数量'),
sa.Column('foreshadows_resolved', sa.Integer(), nullable=True, comment='本章回收的伏笔数量'),
sa.Column('plot_points', sa.JSON(), nullable=True, comment='情节点列表: [\n {\n "content": "情节点描述",\n "importance": 0.9,\n "type": "revelation|conflict|resolution|transition",\n "impact": "对故事的影响描述"\n }\n ]'),
sa.Column('plot_points_count', sa.Integer(), nullable=True, comment='情节点数量'),
sa.Column('character_states', sa.JSON(), nullable=True, comment='角色状态变化: [\n {\n "character_id": "xxx",\n "character_name": "张三",\n "state_before": "犹豫不决",\n "state_after": "坚定信念",\n "psychological_change": "内心描述",\n "key_event": "触发事件",\n "relationship_changes": {"李四": "关系变化"}\n }\n ]'),
sa.Column('scenes', sa.JSON(), nullable=True, comment="场景列表: [{location: '地点', atmosphere: '氛围', duration: '时长'}]"),
sa.Column('pacing', sa.String(length=50), nullable=True, comment='节奏: slow|moderate|fast|varied'),
sa.Column('overall_quality_score', sa.Float(), nullable=True, comment='整体质量评分 0.0-10.0'),
sa.Column('pacing_score', sa.Float(), nullable=True, comment='节奏评分 0.0-10.0'),
sa.Column('engagement_score', sa.Float(), nullable=True, comment='吸引力评分 0.0-10.0'),
sa.Column('coherence_score', sa.Float(), nullable=True, comment='连贯性评分 0.0-10.0'),
sa.Column('analysis_report', sa.Text(), nullable=True, comment='完整的文字分析报告'),
sa.Column('suggestions', sa.JSON(), nullable=True, comment="改进建议列表: ['建议1', '建议2']"),
sa.Column('word_count', sa.Integer(), nullable=True, comment='章节字数'),
sa.Column('dialogue_ratio', sa.Float(), nullable=True, comment='对话占比 0.0-1.0'),
sa.Column('description_ratio', sa.Float(), nullable=True, comment='描写占比 0.0-1.0'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='分析时间'),
sa.ForeignKeyConstraint(['chapter_id'], ['chapters.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
with op.batch_alter_table('plot_analysis', schema=None) as batch_op:
batch_op.create_index(batch_op.f('ix_plot_analysis_chapter_id'), ['chapter_id'], unique=True)
batch_op.create_index(batch_op.f('ix_plot_analysis_project_id'), ['project_id'], unique=False)
op.create_table('regeneration_tasks',
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('chapter_id', sa.String(length=36), nullable=False),
sa.Column('analysis_id', sa.String(length=36), nullable=True, comment='关联的分析结果ID'),
sa.Column('user_id', sa.String(length=50), nullable=False),
sa.Column('project_id', sa.String(length=36), nullable=False),
sa.Column('modification_instructions', sa.Text(), nullable=False, comment='综合修改指令'),
sa.Column('original_suggestions', sa.JSON(), nullable=True, comment='来自分析的原始建议列表'),
sa.Column('selected_suggestion_indices', sa.JSON(), nullable=True, comment='用户选择的建议索引'),
sa.Column('custom_instructions', sa.Text(), nullable=True, comment='用户自定义修改意见'),
sa.Column('style_id', sa.Integer(), nullable=True, comment='写作风格ID'),
sa.Column('target_word_count', sa.Integer(), nullable=True, comment='目标字数'),
sa.Column('focus_areas', sa.JSON(), nullable=True, comment='重点优化方向'),
sa.Column('preserve_elements', sa.JSON(), nullable=True, comment='需要保留的元素配置'),
sa.Column('status', sa.String(length=20), nullable=True, comment='pending/running/completed/failed'),
sa.Column('progress', sa.Integer(), nullable=True, comment='进度 0-100'),
sa.Column('error_message', sa.Text(), nullable=True),
sa.Column('original_content', sa.Text(), nullable=True, comment='原始章节内容快照'),
sa.Column('original_word_count', sa.Integer(), nullable=True, comment='原始字数'),
sa.Column('regenerated_content', sa.Text(), nullable=True, comment='重新生成的内容'),
sa.Column('regenerated_word_count', sa.Integer(), nullable=True, comment='新内容字数'),
sa.Column('version_number', sa.Integer(), nullable=True, comment='版本号'),
sa.Column('version_note', sa.String(length=500), nullable=True, comment='版本说明'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
sa.Column('started_at', sa.DateTime(), nullable=True),
sa.Column('completed_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['chapter_id'], ['chapters.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
with op.batch_alter_table('regeneration_tasks', schema=None) as batch_op:
batch_op.create_index(batch_op.f('ix_regeneration_tasks_chapter_id'), ['chapter_id'], unique=False)
batch_op.create_index(batch_op.f('ix_regeneration_tasks_project_id'), ['project_id'], unique=False)
batch_op.create_index(batch_op.f('ix_regeneration_tasks_user_id'), ['user_id'], unique=False)
op.create_table('story_memories',
sa.Column('id', sa.String(length=100), nullable=False),
sa.Column('project_id', sa.String(length=36), nullable=False),
sa.Column('chapter_id', sa.String(length=36), nullable=True),
sa.Column('memory_type', sa.String(length=50), nullable=False, comment='\n 记忆类型:\n - plot_point: 情节点\n - character_event: 角色事件\n - world_detail: 世界观细节\n - hook: 钩子(悬念/冲突)\n - foreshadow: 伏笔\n - dialogue: 重要对话\n - scene: 场景描写\n '),
sa.Column('title', sa.String(length=200), nullable=True, comment='记忆标题/简述'),
sa.Column('content', sa.Text(), nullable=False, comment='记忆内容摘要(100-500字)'),
sa.Column('full_context', sa.Text(), nullable=True, comment='完整上下文(可选,用于详细记录)'),
sa.Column('related_characters', sa.JSON(), nullable=True, comment="涉及角色ID列表: ['char_id_1', 'char_id_2']"),
sa.Column('related_locations', sa.JSON(), nullable=True, comment="涉及地点列表: ['地点1', '地点2']"),
sa.Column('tags', sa.JSON(), nullable=True, comment="标签列表: ['悬念', '转折', '伏笔', '高潮']"),
sa.Column('importance_score', sa.Float(), nullable=True, comment='重要性评分 0.0-1.0'),
sa.Column('story_timeline', sa.Integer(), nullable=False, comment='故事时间线位置(章节序号)'),
sa.Column('chapter_position', sa.Integer(), nullable=True, comment='章节内位置(字符位置)'),
sa.Column('text_length', sa.Integer(), nullable=True, comment='文本长度(字符数)'),
sa.Column('is_foreshadow', sa.Integer(), nullable=True, comment='伏笔状态: 0=普通记忆, 1=已埋下伏笔, 2=伏笔已回收'),
sa.Column('foreshadow_resolved_at', sa.String(length=100), nullable=True, comment='伏笔回收的章节ID'),
sa.Column('foreshadow_strength', sa.Float(), nullable=True, comment='伏笔强度 0.0-1.0'),
sa.Column('vector_id', sa.String(length=100), nullable=True, comment='向量数据库中的唯一ID'),
sa.Column('embedding_model', sa.String(length=100), nullable=True, comment='使用的embedding模型'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
sa.ForeignKeyConstraint(['chapter_id'], ['chapters.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['foreshadow_resolved_at'], ['chapters.id'], ondelete='SET NULL'),
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('vector_id')
)
with op.batch_alter_table('story_memories', schema=None) as batch_op:
batch_op.create_index(batch_op.f('ix_story_memories_chapter_id'), ['chapter_id'], unique=False)
batch_op.create_index(batch_op.f('ix_story_memories_memory_type'), ['memory_type'], unique=False)
batch_op.create_index(batch_op.f('ix_story_memories_project_id'), ['project_id'], unique=False)
batch_op.create_index(batch_op.f('ix_story_memories_story_timeline'), ['story_timeline'], unique=False)
op.create_table('organization_members',
sa.Column('id', sa.String(length=36), nullable=False, comment='成员关系ID'),
sa.Column('organization_id', sa.String(length=36), nullable=False, comment='组织ID'),
sa.Column('character_id', sa.String(length=36), nullable=False, comment='角色ID'),
sa.Column('position', sa.String(length=100), nullable=False, comment='职位名称'),
sa.Column('rank', sa.Integer(), nullable=True, comment='职位等级'),
sa.Column('status', sa.String(length=20), nullable=True, comment='状态:active/retired/expelled/deceased'),
sa.Column('joined_at', sa.String(length=100), nullable=True, comment='加入时间(故事时间)'),
sa.Column('left_at', sa.String(length=100), nullable=True, comment='离开时间(故事时间)'),
sa.Column('loyalty', sa.Integer(), nullable=True, comment='忠诚度:0-100'),
sa.Column('contribution', sa.Integer(), nullable=True, comment='贡献度:0-100'),
sa.Column('source', sa.String(length=20), nullable=True, comment='来源:ai/manual'),
sa.Column('notes', sa.Text(), nullable=True, comment='备注'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
sa.ForeignKeyConstraint(['character_id'], ['characters.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
with op.batch_alter_table('organization_members', schema=None) as batch_op:
batch_op.create_index(batch_op.f('ix_organization_members_character_id'), ['character_id'], unique=False)
batch_op.create_index(batch_op.f('ix_organization_members_organization_id'), ['organization_id'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('organization_members', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('ix_organization_members_organization_id'))
batch_op.drop_index(batch_op.f('ix_organization_members_character_id'))
op.drop_table('organization_members')
with op.batch_alter_table('story_memories', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('ix_story_memories_story_timeline'))
batch_op.drop_index(batch_op.f('ix_story_memories_project_id'))
batch_op.drop_index(batch_op.f('ix_story_memories_memory_type'))
batch_op.drop_index(batch_op.f('ix_story_memories_chapter_id'))
op.drop_table('story_memories')
with op.batch_alter_table('regeneration_tasks', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('ix_regeneration_tasks_user_id'))
batch_op.drop_index(batch_op.f('ix_regeneration_tasks_project_id'))
batch_op.drop_index(batch_op.f('ix_regeneration_tasks_chapter_id'))
op.drop_table('regeneration_tasks')
with op.batch_alter_table('plot_analysis', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('ix_plot_analysis_project_id'))
batch_op.drop_index(batch_op.f('ix_plot_analysis_chapter_id'))
op.drop_table('plot_analysis')
with op.batch_alter_table('organizations', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('ix_organizations_project_id'))
op.drop_table('organizations')
op.drop_table('generation_history')
with op.batch_alter_table('character_relationships', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('ix_character_relationships_relationship_type_id'))
batch_op.drop_index(batch_op.f('ix_character_relationships_project_id'))
batch_op.drop_index(batch_op.f('ix_character_relationships_character_to_id'))
batch_op.drop_index(batch_op.f('ix_character_relationships_character_from_id'))
op.drop_table('character_relationships')
with op.batch_alter_table('character_careers', schema=None) as batch_op:
batch_op.drop_index('idx_character_id')
batch_op.drop_index('idx_character_career')
batch_op.drop_index('idx_career_type')
op.drop_table('character_careers')
with op.batch_alter_table('analysis_tasks', schema=None) as batch_op:
batch_op.drop_index('idx_status')
batch_op.drop_index('idx_chapter_id_created')
op.drop_table('analysis_tasks')
op.drop_table('project_default_styles')
op.drop_table('characters')
op.drop_table('chapters')
op.drop_table('writing_styles')
op.drop_table('outlines')
with op.batch_alter_table('careers', schema=None) as batch_op:
batch_op.drop_index('idx_type')
batch_op.drop_index('idx_project_id')
op.drop_table('careers')
with op.batch_alter_table('users', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('ix_users_username'))
batch_op.drop_index(batch_op.f('ix_users_user_id'))
batch_op.drop_index(batch_op.f('ix_users_linuxdo_id'))
op.drop_table('users')
with op.batch_alter_table('user_passwords', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('ix_user_passwords_user_id'))
op.drop_table('user_passwords')
with op.batch_alter_table('settings', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('ix_settings_user_id'))
batch_op.drop_index('idx_user_id')
op.drop_table('settings')
with op.batch_alter_table('relationship_types', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('ix_relationship_types_id'))
op.drop_table('relationship_types')
with op.batch_alter_table('prompt_templates', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('ix_prompt_templates_user_id'))
batch_op.drop_index('idx_user_template')
op.drop_table('prompt_templates')
with op.batch_alter_table('projects', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('ix_projects_user_id'))
op.drop_table('projects')
with op.batch_alter_table('mcp_plugins', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('ix_mcp_plugins_user_id'))
batch_op.drop_index('idx_user_plugin')
batch_op.drop_index('idx_user_enabled')
op.drop_table('mcp_plugins')
op.drop_table('batch_generation_tasks')
# ### end Alembic commands ###
@@ -0,0 +1,177 @@
"""初始化SQLite预置数据
Revision ID: a1b2c3d4e5f6
Revises: fbeb1038c728
Create Date: 2025-12-27 08:56:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy import table, column, String, Integer, Text
# revision identifiers, used by Alembic.
revision: str = 'a1b2c3d4e5f6'
down_revision: Union[str, None] = 'fbeb1038c728'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""插入预置数据"""
# ==================== 1. 插入关系类型数据 ====================
relationship_types_table = table(
'relationship_types',
column('name', String),
column('category', String),
column('reverse_name', String),
column('intimacy_range', String),
column('icon', String),
column('description', Text),
)
relationship_types_data = [
# 家庭关系
{"name": "父亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👨", "description": "父子/父女关系"},
{"name": "母亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👩", "description": "母子/母女关系"},
{"name": "兄弟", "category": "family", "reverse_name": "兄弟", "intimacy_range": "high", "icon": "👬", "description": "兄弟关系"},
{"name": "姐妹", "category": "family", "reverse_name": "姐妹", "intimacy_range": "high", "icon": "👭", "description": "姐妹关系"},
{"name": "子女", "category": "family", "reverse_name": "父母", "intimacy_range": "high", "icon": "👶", "description": "子女关系"},
{"name": "配偶", "category": "family", "reverse_name": "配偶", "intimacy_range": "high", "icon": "💑", "description": "夫妻关系"},
{"name": "恋人", "category": "family", "reverse_name": "恋人", "intimacy_range": "high", "icon": "💕", "description": "恋爱关系"},
# 社交关系
{"name": "师父", "category": "social", "reverse_name": "徒弟", "intimacy_range": "high", "icon": "🎓", "description": "师徒关系(师父视角)"},
{"name": "徒弟", "category": "social", "reverse_name": "师父", "intimacy_range": "high", "icon": "📚", "description": "师徒关系(徒弟视角)"},
{"name": "朋友", "category": "social", "reverse_name": "朋友", "intimacy_range": "medium", "icon": "🤝", "description": "朋友关系"},
{"name": "同学", "category": "social", "reverse_name": "同学", "intimacy_range": "medium", "icon": "🎒", "description": "同学关系"},
{"name": "邻居", "category": "social", "reverse_name": "邻居", "intimacy_range": "low", "icon": "🏘️", "description": "邻居关系"},
{"name": "知己", "category": "social", "reverse_name": "知己", "intimacy_range": "high", "icon": "💙", "description": "知心好友"},
# 职业关系
{"name": "上司", "category": "professional", "reverse_name": "下属", "intimacy_range": "low", "icon": "👔", "description": "上下级关系(上司视角)"},
{"name": "下属", "category": "professional", "reverse_name": "上司", "intimacy_range": "low", "icon": "💼", "description": "上下级关系(下属视角)"},
{"name": "同事", "category": "professional", "reverse_name": "同事", "intimacy_range": "medium", "icon": "🤵", "description": "同事关系"},
{"name": "合作伙伴", "category": "professional", "reverse_name": "合作伙伴", "intimacy_range": "medium", "icon": "🤜🤛", "description": "合作关系"},
# 敌对关系
{"name": "敌人", "category": "hostile", "reverse_name": "敌人", "intimacy_range": "low", "icon": "⚔️", "description": "敌对关系"},
{"name": "仇人", "category": "hostile", "reverse_name": "仇人", "intimacy_range": "low", "icon": "💢", "description": "仇恨关系"},
{"name": "竞争对手", "category": "hostile", "reverse_name": "竞争对手", "intimacy_range": "low", "icon": "🎯", "description": "竞争关系"},
{"name": "宿敌", "category": "hostile", "reverse_name": "宿敌", "intimacy_range": "low", "icon": "", "description": "宿命之敌"},
]
op.bulk_insert(relationship_types_table, relationship_types_data)
print(f"✅ SQLite: 已插入 {len(relationship_types_data)} 条关系类型数据")
# ==================== 2. 插入全局写作风格预设 ====================
writing_styles_table = table(
'writing_styles',
column('user_id', String),
column('name', String),
column('style_type', String),
column('preset_id', String),
column('description', Text),
column('prompt_content', Text),
column('order_index', Integer),
)
writing_styles_data = [
{
"user_id": None, # NULL 表示全局预设
"name": "自然流畅",
"style_type": "preset",
"preset_id": "natural",
"description": "自然流畅的叙事风格,适合现代都市、现实题材",
"prompt_content": """写作风格要求:
1. 语言简洁明快,贴近现代口语
2. 多用短句,节奏流畅
3. 注重情感细节的自然流露
4. 避免过度修饰和复杂句式""",
"order_index": 1
},
{
"user_id": None,
"name": "古典优雅",
"style_type": "preset",
"preset_id": "classical",
"description": "古典文雅的写作风格,适合古装、仙侠题材",
"prompt_content": """写作风格要求:
1. 使用文言、半文言或典雅的白话
2. 适当运用古典诗词意象
3. 注重意境营造和韵味
4. 对话和描写保持古典美感""",
"order_index": 2
},
{
"user_id": None,
"name": "现代简约",
"style_type": "preset",
"preset_id": "modern",
"description": "现代简约风格,适合轻小说、网文快节奏叙事",
"prompt_content": """写作风格要求:
1. 语言直白简练,信息密度高
2. 多用对话推进情节
3. 避免冗长描写,突出关键动作
4. 节奏明快,适合快速阅读""",
"order_index": 3
},
{
"user_id": None,
"name": "文艺细腻",
"style_type": "preset",
"preset_id": "literary",
"description": "文艺细腻风格,注重心理描写和氛围营造",
"prompt_content": """写作风格要求:
1. 注重心理活动和情感细节
2. 善用环境描写烘托氛围
3. 语言优美,富有文学性
4. 适当使用比喻、象征等修辞手法""",
"order_index": 4
},
{
"user_id": None,
"name": "紧张悬疑",
"style_type": "preset",
"preset_id": "suspense",
"description": "紧张悬疑风格,适合推理、惊悚题材",
"prompt_content": """写作风格要求:
1. 营造紧张压迫的氛围
2. 多用短句加快节奏
3. 善于设置悬念和伏笔
4. 注重细节描写,为推理埋下线索""",
"order_index": 5
},
{
"user_id": None,
"name": "幽默诙谐",
"style_type": "preset",
"preset_id": "humorous",
"description": "幽默诙谐风格,适合轻松搞笑题材",
"prompt_content": """写作风格要求:
1. 语言活泼风趣,善用俏皮话
2. 注重对话的喜剧效果
3. 适当夸张和反转制造笑点
4. 保持轻松愉快的基调""",
"order_index": 6
},
]
op.bulk_insert(writing_styles_table, writing_styles_data)
print(f"✅ SQLite: 已插入 {len(writing_styles_data)} 条全局写作风格预设")
def downgrade() -> None:
"""删除预置数据"""
# 删除写作风格预设(只删除全局预设)
op.execute("DELETE FROM writing_styles WHERE user_id IS NULL")
print("✅ SQLite: 已删除全局写作风格预设")
# 删除关系类型
op.execute("DELETE FROM relationship_types")
print("✅ SQLite: 已删除关系类型数据")
@@ -0,0 +1,34 @@
"""添加system_prompt字段到settings表
Revision ID: 7899f8d4d839
Revises: a1b2c3d4e5f6
Create Date: 2025-12-27 17:00:35.440551
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '7899f8d4d839'
down_revision: Union[str, None] = 'a1b2c3d4e5f6'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('settings', schema=None) as batch_op:
batch_op.add_column(sa.Column('system_prompt', sa.Text(), nullable=True, comment='系统级别提示词,每次AI调用都会使用'))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('settings', schema=None) as batch_op:
batch_op.drop_column('system_prompt')
# ### end Alembic commands ###
+388
View File
@@ -0,0 +1,388 @@
"""
管理员API - 用户管理功能
"""
from fastapi import APIRouter, HTTPException, Request, Depends
from pydantic import BaseModel, Field
from typing import Optional, List
from datetime import datetime
import hashlib
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.models.user import User
from app.user_manager import user_manager
from app.user_password import password_manager
from app.logger import get_logger
logger = get_logger(__name__)
router = APIRouter(prefix="/admin", tags=["管理员"])
# ==================== 请求/响应模型 ====================
class CreateUserRequest(BaseModel):
"""创建用户请求"""
username: str = Field(..., min_length=3, max_length=20, description="用户名")
display_name: str = Field(..., min_length=2, max_length=50, description="显示名称")
password: Optional[str] = Field(None, min_length=6, description="初始密码,留空则自动生成")
avatar_url: Optional[str] = Field(None, description="头像URL")
trust_level: int = Field(0, ge=0, le=9, description="信任等级")
is_admin: bool = Field(False, description="是否为管理员")
class UpdateUserRequest(BaseModel):
"""更新用户请求"""
display_name: Optional[str] = Field(None, min_length=2, max_length=50)
avatar_url: Optional[str] = None
trust_level: Optional[int] = Field(None, ge=-1, le=9)
is_admin: Optional[bool] = Field(None, description="是否为管理员")
class ToggleStatusRequest(BaseModel):
"""切换用户状态请求"""
is_active: bool = Field(..., description="true=启用, false=禁用")
class ResetPasswordRequest(BaseModel):
"""重置密码请求"""
new_password: Optional[str] = Field(None, min_length=6, description="新密码,留空则重置为默认密码")
class UserResponse(BaseModel):
"""用户响应"""
user_id: str
username: str
display_name: str
avatar_url: Optional[str]
trust_level: int
is_admin: bool
is_active: bool
linuxdo_id: str
created_at: str
last_login: Optional[str]
class CreateUserResponse(BaseModel):
"""创建用户响应"""
success: bool
message: str
user: dict
default_password: Optional[str] = None
# ==================== 权限检查依赖 ====================
async def check_admin(request: Request) -> User:
"""检查管理员权限"""
user = getattr(request.state, "user", None)
if not user:
raise HTTPException(status_code=401, detail="未登录")
if not user.is_admin:
raise HTTPException(status_code=403, detail="需要管理员权限")
return user
# ==================== API 端点 ====================
@router.get("/users")
async def get_users(
admin: User = Depends(check_admin),
db: AsyncSession = Depends(get_db)
):
"""获取用户列表(仅管理员)"""
try:
all_users = await user_manager.get_all_users()
users_data = []
for user in all_users:
# user_manager 返回的是 Pydantic User 对象,直接转为 dict
user_dict = user.model_dump()
user_dict["is_active"] = user.trust_level != -1
users_data.append(user_dict)
logger.info(f"管理员 {admin.user_id} 获取用户列表,共 {len(users_data)} 个用户")
return {
"total": len(users_data),
"users": users_data
}
except Exception as e:
logger.error(f"获取用户列表失败: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"获取用户列表失败: {str(e)}")
@router.post("/users")
async def create_user(
data: CreateUserRequest,
admin: User = Depends(check_admin),
db: AsyncSession = Depends(get_db)
):
"""添加用户(仅管理员)"""
try:
# 检查用户名是否已存在
all_users = await user_manager.get_all_users()
for user in all_users:
if user.username == data.username:
raise HTTPException(status_code=409, detail="用户名已存在")
# 生成用户ID
user_id = f"admin_created_{hashlib.md5(data.username.encode()).hexdigest()[:16]}"
# 创建用户
new_user = await user_manager.create_or_update_from_linuxdo(
linuxdo_id=user_id,
username=data.username,
display_name=data.display_name,
avatar_url=data.avatar_url,
trust_level=data.trust_level
)
# 设置管理员标志
if data.is_admin:
# 直接更新数据库中的is_admin字段
async with await user_manager._get_session() as session:
result = await session.execute(
select(User).where(User.user_id == user_id)
)
db_user = result.scalar_one_or_none()
if db_user:
db_user.is_admin = True
await session.commit()
new_user.is_admin = True
# 设置密码
actual_password = await password_manager.set_password(
user_id=new_user.user_id,
username=data.username,
password=data.password
)
# Settings 将在首次访问设置页面时自动创建(延迟初始化)
logger.info(f"管理员 {admin.user_id} 创建了新用户 {new_user.user_id} ({data.username})")
return CreateUserResponse(
success=True,
message="用户创建成功",
user=new_user.model_dump(),
default_password=actual_password if not data.password else None
)
except HTTPException:
raise
except Exception as e:
logger.error(f"创建用户失败: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"创建用户失败: {str(e)}")
@router.put("/users/{user_id}")
async def update_user(
user_id: str,
data: UpdateUserRequest,
admin: User = Depends(check_admin),
db: AsyncSession = Depends(get_db)
):
"""编辑用户信息(仅管理员)"""
try:
# 获取目标用户
target_user = await user_manager.get_user(user_id)
if not target_user:
raise HTTPException(status_code=404, detail="用户不存在")
# 更新用户信息
async with await user_manager._get_session() as session:
result = await session.execute(
select(User).where(User.user_id == user_id)
)
db_user = result.scalar_one_or_none()
if not db_user:
raise HTTPException(status_code=404, detail="用户不存在")
# 更新字段
if data.display_name is not None:
db_user.display_name = data.display_name
if data.avatar_url is not None:
db_user.avatar_url = data.avatar_url
if data.trust_level is not None:
db_user.trust_level = data.trust_level
if data.is_admin is not None:
# 检查是否是最后一个管理员
if db_user.is_admin and not data.is_admin:
all_users = await user_manager.get_all_users()
admin_count = sum(1 for u in all_users if u.is_admin)
if admin_count <= 1:
raise HTTPException(status_code=400, detail="不能取消最后一个管理员的权限")
db_user.is_admin = data.is_admin
await session.commit()
await session.refresh(db_user)
logger.info(f"管理员 {admin.user_id} 更新了用户 {user_id} 的信息")
updated_user = await user_manager.get_user(user_id)
user_dict = updated_user.model_dump()
user_dict["is_active"] = updated_user.trust_level != -1
return {
"success": True,
"message": "用户信息更新成功",
"user": user_dict
}
except HTTPException:
raise
except Exception as e:
logger.error(f"更新用户失败: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"更新用户失败: {str(e)}")
@router.post("/users/{user_id}/toggle-status")
async def toggle_user_status(
user_id: str,
data: ToggleStatusRequest,
admin: User = Depends(check_admin),
db: AsyncSession = Depends(get_db)
):
"""切换用户状态(启用/禁用)(仅管理员)"""
try:
# 不允许禁用自己
if user_id == admin.user_id:
raise HTTPException(status_code=400, detail="不能禁用自己的账号")
# 获取目标用户
target_user = await user_manager.get_user(user_id)
if not target_user:
raise HTTPException(status_code=404, detail="用户不存在")
# 更新状态
async with await user_manager._get_session() as session:
result = await session.execute(
select(User).where(User.user_id == user_id)
)
db_user = result.scalar_one_or_none()
if not db_user:
raise HTTPException(status_code=404, detail="用户不存在")
if data.is_active:
# 启用用户:恢复trust_level为0(或之前的值)
db_user.trust_level = 0
else:
# 禁用用户:设置trust_level为-1
db_user.trust_level = -1
await session.commit()
status_text = "启用" if data.is_active else "禁用"
logger.info(f"管理员 {admin.user_id} {status_text}了用户 {user_id}")
return {
"success": True,
"message": f"用户已{status_text}",
"is_active": data.is_active
}
except HTTPException:
raise
except Exception as e:
logger.error(f"切换用户状态失败: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"切换用户状态失败: {str(e)}")
@router.post("/users/{user_id}/reset-password")
async def reset_password(
user_id: str,
data: ResetPasswordRequest,
admin: User = Depends(check_admin),
db: AsyncSession = Depends(get_db)
):
"""重置用户密码(仅管理员)"""
try:
# 获取目标用户
target_user = await user_manager.get_user(user_id)
if not target_user:
raise HTTPException(status_code=404, detail="用户不存在")
# 重置密码
actual_password = await password_manager.set_password(
user_id=user_id,
username=target_user.username,
password=data.new_password
)
logger.info(f"管理员 {admin.user_id} 重置了用户 {user_id} 的密码")
return {
"success": True,
"message": "密码重置成功",
"new_password": actual_password
}
except HTTPException:
raise
except Exception as e:
logger.error(f"重置密码失败: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"重置密码失败: {str(e)}")
@router.delete("/users/{user_id}")
async def delete_user(
user_id: str,
admin: User = Depends(check_admin),
db: AsyncSession = Depends(get_db)
):
"""删除用户(仅管理员,慎用)"""
try:
# 不允许删除自己
if user_id == admin.user_id:
raise HTTPException(status_code=400, detail="不能删除自己的账号")
# 获取目标用户
target_user = await user_manager.get_user(user_id)
if not target_user:
raise HTTPException(status_code=404, detail="用户不存在")
# 检查是否是最后一个管理员
if target_user.is_admin:
all_users = await user_manager.get_all_users()
admin_count = sum(1 for u in all_users if u.is_admin)
if admin_count <= 1:
raise HTTPException(status_code=400, detail="不能删除最后一个管理员账号")
# 删除用户(包括密码记录)
async with await user_manager._get_session() as session:
# 删除用户记录
result = await session.execute(
select(User).where(User.user_id == user_id)
)
db_user = result.scalar_one_or_none()
if db_user:
await session.delete(db_user)
# 删除密码记录
from app.models.user import UserPassword
result = await session.execute(
select(UserPassword).where(UserPassword.user_id == user_id)
)
pwd_record = result.scalar_one_or_none()
if pwd_record:
await session.delete(pwd_record)
await session.commit()
logger.warning(f"管理员 {admin.user_id} 删除了用户 {user_id}")
return {
"success": True,
"message": "用户已删除"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"删除用户失败: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"删除用户失败: {str(e)}")
+250 -32
View File
@@ -9,7 +9,7 @@ import hashlib
from datetime import datetime, timedelta, timezone
from app.services.oauth_service import LinuxDOOAuthService
from app.user_manager import user_manager
from app.database import init_db
from app.user_password import password_manager
from app.logger import get_logger
from app.config import settings
@@ -49,6 +49,25 @@ class LocalLoginResponse(BaseModel):
user: Optional[dict] = None
class SetPasswordRequest(BaseModel):
"""设置密码请求"""
password: str
class SetPasswordResponse(BaseModel):
"""设置密码响应"""
success: bool
message: str
class PasswordStatusResponse(BaseModel):
"""密码状态响应"""
has_password: bool
has_custom_password: bool
username: Optional[str] = None
default_password: Optional[str] = None
@router.get("/config")
async def get_auth_config():
"""获取认证配置信息"""
@@ -60,37 +79,79 @@ async def get_auth_config():
@router.post("/local/login", response_model=LocalLoginResponse)
async def local_login(request: LocalLoginRequest, response: Response):
"""本地账户登录"""
"""本地账户登录(支持.env配置的管理员账号和Linux DO授权后绑定的账号)"""
# 检查是否启用本地登录
if not settings.LOCAL_AUTH_ENABLED:
raise HTTPException(status_code=403, detail="本地账户登录未启用")
# 检查是否配置了本地账户
if not settings.LOCAL_AUTH_USERNAME or not settings.LOCAL_AUTH_PASSWORD:
raise HTTPException(status_code=500, detail="本地账户未配置")
logger.info(f"[本地登录] 尝试登录用户名: {request.username}")
# 验证用户名和密码
if request.username != settings.LOCAL_AUTH_USERNAME or request.password != settings.LOCAL_AUTH_PASSWORD:
raise HTTPException(status_code=401, detail="用户名或密码错误")
# 首先尝试查找 Linux DO 授权后绑定的账号
all_users = await user_manager.get_all_users()
target_user = None
# 生成本地用户ID(使用用户名的hash)
user_id = f"local_{hashlib.md5(request.username.encode()).hexdigest()[:16]}"
for user in all_users:
# 同时检查 users 表的 username 和 user_passwords 表的 username
password_username = await password_manager.get_username(user.user_id)
if user.username == request.username or password_username == request.username:
target_user = user
logger.info(f"[本地登录] 找到 Linux DO 授权用户: {user.user_id}")
break
# 创建或更新本地用户
user = await user_manager.create_or_update_from_linuxdo(
linuxdo_id=user_id,
username=request.username,
display_name=settings.LOCAL_AUTH_DISPLAY_NAME,
avatar_url=None,
trust_level=9 # 本地用户给予高信任级别
)
# 如果找到了 Linux DO 授权的用户
if target_user:
# 检查是否有密码
if not await password_manager.has_password(target_user.user_id):
logger.warning(f"[本地登录] 用户 {target_user.user_id} 没有设置密码")
raise HTTPException(status_code=401, detail="用户名或密码错误")
# 验证密码
if not await password_manager.verify_password(target_user.user_id, request.password):
logger.warning(f"[本地登录] 用户 {target_user.user_id} 密码验证失败")
raise HTTPException(status_code=401, detail="用户名或密码错误")
logger.info(f"[本地登录] Linux DO 授权用户 {target_user.user_id} 登录成功")
user = target_user
else:
# 没有找到 Linux DO 用户,尝试 .env 配置的管理员账号
logger.info(f"[本地登录] 未找到 Linux DO 用户,检查 .env 管理员账号")
# 检查是否配置了本地账户
if not settings.LOCAL_AUTH_USERNAME or not settings.LOCAL_AUTH_PASSWORD:
raise HTTPException(status_code=401, detail="用户名或密码错误")
# 生成本地用户ID(使用用户名的hash)
user_id = f"local_{hashlib.md5(request.username.encode()).hexdigest()[:16]}"
# 检查用户是否存在
user = await user_manager.get_user(user_id)
# 如果用户不存在,使用.env中的默认密码验证
if not user:
# 验证用户名和密码(使用.env配置)
if request.username != settings.LOCAL_AUTH_USERNAME or request.password != settings.LOCAL_AUTH_PASSWORD:
raise HTTPException(status_code=401, detail="用户名或密码错误")
# 创建本地用户
user = await user_manager.create_or_update_from_linuxdo(
linuxdo_id=user_id,
username=request.username,
display_name=settings.LOCAL_AUTH_DISPLAY_NAME,
avatar_url=None,
trust_level=9 # 本地用户给予高信任级别
)
# 为新用户设置默认密码到数据库
await password_manager.set_password(user.user_id, request.username, request.password)
logger.info(f"[本地登录] 管理员用户 {user.user_id} 初始密码已设置到数据库")
else:
# 用户已存在,使用数据库中的密码验证
if not await password_manager.verify_password(user.user_id, request.password):
raise HTTPException(status_code=401, detail="用户名或密码错误")
logger.info(f"[本地登录] 管理员用户 {user.user_id} 登录成功")
# 初始化用户数据库
try:
await init_db(user.user_id)
logger.info(f"本地用户 {user.user_id} 数据库初始化成功")
except Exception as e:
logger.error(f"本地用户 {user.user_id} 数据库初始化失败: {e}")
# Settings 将在首次访问设置页面时自动创建(延迟初始化)
# 设置 Cookie2小时有效)
max_age = settings.SESSION_EXPIRE_MINUTES * 60
@@ -189,13 +250,12 @@ async def _handle_callback(
trust_level=trust_level
)
# 3.5. 初始化用户数据库(如果是新用户
try:
await init_db(user.user_id)
logger.info(f"用户 {user.user_id} 数据库初始化成功")
except Exception as e:
logger.error(f"用户 {user.user_id} 数据库初始化失败: {e}")
# 继续执行,不影响登录流程(可能是已存在的用户)
# 3.1. 检查是否是首次登录(没有密码记录
is_first_login = not await password_manager.has_password(user.user_id)
if is_first_login:
logger.info(f"用户 {user.user_id} ({username}) 首次登录,需要初始化密码")
# Settings 将在首次访问设置页面时自动创建(延迟初始化)
# 4. 设置 Cookie 并重定向到前端回调页面
# 使用配置的前端URL,支持不同的部署环境
@@ -229,6 +289,17 @@ async def _handle_callback(
samesite="lax"
)
# 如果是首次登录,设置标记 Cookie(5分钟有效,仅用于前端显示初始密码提示)
if is_first_login:
redirect_response.set_cookie(
key="first_login",
value="true",
max_age=300, # 5分钟有效
httponly=False, # 前端需要读取
samesite="lax"
)
logger.info(f"✅ [OAuth登录] 用户 {user.user_id} 首次登录,已设置 first_login 标记")
return redirect_response
@@ -337,4 +408,151 @@ async def get_current_user(request: Request):
if not hasattr(request.state, "user") or not request.state.user:
raise HTTPException(status_code=401, detail="未登录")
return request.state.user.dict()
return request.state.user.dict()
@router.get("/password/status", response_model=PasswordStatusResponse)
async def get_password_status(request: Request):
"""获取当前用户的密码状态"""
if not hasattr(request.state, "user") or not request.state.user:
raise HTTPException(status_code=401, detail="未登录")
user = request.state.user
has_password = await password_manager.has_password(user.user_id)
has_custom = await password_manager.has_custom_password(user.user_id)
username = await password_manager.get_username(user.user_id)
# 如果使用默认密码,返回默认密码供用户查看
default_password = None
if has_password and not has_custom:
default_password = f"{user.username}@666"
return PasswordStatusResponse(
has_password=has_password,
has_custom_password=has_custom,
username=username or user.username,
default_password=default_password
)
@router.post("/password/set", response_model=SetPasswordResponse)
async def set_user_password(request: Request, password_req: SetPasswordRequest):
"""设置当前用户的密码"""
if not hasattr(request.state, "user") or not request.state.user:
raise HTTPException(status_code=401, detail="未登录")
user = request.state.user
# 验证密码强度(至少6个字符)
if len(password_req.password) < 6:
raise HTTPException(status_code=400, detail="密码长度至少为6个字符")
# 设置密码
await password_manager.set_password(user.user_id, user.username, password_req.password)
logger.info(f"用户 {user.user_id} ({user.username}) 设置了自定义密码")
return SetPasswordResponse(
success=True,
message="密码设置成功"
)
@router.post("/password/initialize", response_model=SetPasswordResponse)
async def initialize_user_password(request: Request, password_req: SetPasswordRequest):
"""
初始化首次登录用户的密码
用于首次通过 Linux DO 授权登录的用户,可以选择设置自定义密码或使用默认密码
"""
if not hasattr(request.state, "user") or not request.state.user:
raise HTTPException(status_code=401, detail="未登录")
user = request.state.user
# 检查是否已经有密码(防止重复初始化)
if await password_manager.has_password(user.user_id):
raise HTTPException(status_code=400, detail="密码已经初始化,请使用密码修改功能")
# 验证密码强度(至少6个字符)
if len(password_req.password) < 6:
raise HTTPException(status_code=400, detail="密码长度至少为6个字符")
# 设置密码
await password_manager.set_password(user.user_id, user.username, password_req.password)
logger.info(f"用户 {user.user_id} ({user.username}) 初始化密码成功")
return SetPasswordResponse(
success=True,
message="密码初始化成功"
)
@router.post("/bind/login", response_model=LocalLoginResponse)
async def bind_account_login(request: LocalLoginRequest, response: Response):
"""使用绑定的账号密码登录(LinuxDO授权后绑定的账号)"""
# 查找用户
all_users = await user_manager.get_all_users()
target_user = None
logger.info(f"[绑定账号登录] 尝试登录用户名: {request.username}")
logger.info(f"[绑定账号登录] 当前共有 {len(all_users)} 个用户")
for user in all_users:
# 同时检查 users 表的 username 和 user_passwords 表的 username
password_username = await password_manager.get_username(user.user_id)
logger.info(f"[绑定账号登录] 检查用户 {user.user_id}: users.username={user.username}, passwords.username={password_username}")
if user.username == request.username or password_username == request.username:
target_user = user
logger.info(f"[绑定账号登录] 找到匹配用户: {user.user_id}")
break
if not target_user:
logger.warning(f"[绑定账号登录] 用户名 {request.username} 未找到")
raise HTTPException(status_code=401, detail="用户名或密码错误")
# 检查是否有密码记录
has_pwd = await password_manager.has_password(target_user.user_id)
if not has_pwd:
logger.warning(f"[绑定账号登录] 用户 {target_user.user_id} 没有设置密码")
raise HTTPException(status_code=401, detail="用户名或密码错误")
# 验证密码
is_valid = await password_manager.verify_password(target_user.user_id, request.password)
logger.info(f"[绑定账号登录] 用户 {target_user.user_id} 密码验证结果: {is_valid}")
if not is_valid:
raise HTTPException(status_code=401, detail="用户名或密码错误")
# Settings 将在首次访问设置页面时自动创建(延迟初始化)
# 设置 Cookie2小时有效)
max_age = settings.SESSION_EXPIRE_MINUTES * 60
response.set_cookie(
key="user_id",
value=target_user.user_id,
max_age=max_age,
httponly=True,
samesite="lax"
)
# 设置过期时间戳 Cookie(用于前端判断)
china_now = get_china_now()
expire_time = china_now + timedelta(minutes=settings.SESSION_EXPIRE_MINUTES)
expire_at = int(expire_time.timestamp())
logger.info(f"✅ [绑定账号登录] 用户 {target_user.user_id} ({request.username}) 登录成功,会话有效期 {settings.SESSION_EXPIRE_MINUTES} 分钟")
response.set_cookie(
key="session_expire_at",
value=str(expire_at),
max_age=max_age,
httponly=False, # 前端需要读取
samesite="lax"
)
return LocalLoginResponse(
success=True,
message="登录成功",
user=target_user.dict()
)
+938
View File
@@ -0,0 +1,938 @@
"""职业管理API"""
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, and_
import json
from typing import AsyncGenerator
from app.database import get_db
from app.utils.sse_response import SSEResponse, create_sse_response
from app.models.career import Career, CharacterCareer
from app.models.character import Character
from app.models.project import Project
from app.schemas.career import (
CareerCreate,
CareerUpdate,
CareerResponse,
CareerListResponse,
CareerGenerateRequest,
CharacterCareerResponse,
CharacterCareerDetail,
SetMainCareerRequest,
AddSubCareerRequest,
UpdateCareerStageRequest,
CareerStage
)
from app.services.ai_service import AIService
from app.logger import get_logger
from app.api.settings import get_user_ai_service
router = APIRouter(prefix="/careers", tags=["职业管理"])
logger = get_logger(__name__)
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
"""验证用户是否有权访问指定项目"""
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
result = await db.execute(
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
return project
@router.get("", response_model=CareerListResponse, summary="获取职业列表")
async def get_careers(
project_id: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""获取指定项目的所有职业"""
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(project_id, user_id, db)
# 获取总数
count_result = await db.execute(
select(func.count(Career.id)).where(Career.project_id == project_id)
)
total = count_result.scalar_one()
# 获取职业列表
result = await db.execute(
select(Career)
.where(Career.project_id == project_id)
.order_by(Career.type, Career.created_at.desc())
)
careers = result.scalars().all()
# 分类返回
main_careers = []
sub_careers = []
for career in careers:
# 解析JSON字段
stages = json.loads(career.stages) if career.stages else []
attribute_bonuses = json.loads(career.attribute_bonuses) if career.attribute_bonuses else None
career_dict = {
"id": career.id,
"project_id": career.project_id,
"name": career.name,
"type": career.type,
"description": career.description,
"category": career.category,
"stages": stages,
"max_stage": career.max_stage,
"requirements": career.requirements,
"special_abilities": career.special_abilities,
"worldview_rules": career.worldview_rules,
"attribute_bonuses": attribute_bonuses,
"source": career.source,
"created_at": career.created_at,
"updated_at": career.updated_at
}
if career.type == "main":
main_careers.append(career_dict)
else:
sub_careers.append(career_dict)
return CareerListResponse(
total=total,
main_careers=main_careers,
sub_careers=sub_careers
)
@router.post("", response_model=CareerResponse, summary="创建职业")
async def create_career(
career_data: CareerCreate,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""手动创建职业"""
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(career_data.project_id, user_id, db)
try:
# 转换stages为JSON字符串
stages_json = json.dumps([stage.model_dump() for stage in career_data.stages], ensure_ascii=False)
attribute_bonuses_json = json.dumps(career_data.attribute_bonuses, ensure_ascii=False) if career_data.attribute_bonuses else None
# 创建职业
career = Career(
project_id=career_data.project_id,
name=career_data.name,
type=career_data.type,
description=career_data.description,
category=career_data.category,
stages=stages_json,
max_stage=career_data.max_stage,
requirements=career_data.requirements,
special_abilities=career_data.special_abilities,
worldview_rules=career_data.worldview_rules,
attribute_bonuses=attribute_bonuses_json,
source=career_data.source
)
db.add(career)
await db.commit()
await db.refresh(career)
logger.info(f"✅ 创建职业成功:{career.name} (ID: {career.id}, 类型: {career.type})")
return CareerResponse(
id=career.id,
project_id=career.project_id,
name=career.name,
type=career.type,
description=career.description,
category=career.category,
stages=career_data.stages,
max_stage=career.max_stage,
requirements=career.requirements,
special_abilities=career.special_abilities,
worldview_rules=career.worldview_rules,
attribute_bonuses=career_data.attribute_bonuses,
source=career.source,
created_at=career.created_at,
updated_at=career.updated_at
)
except Exception as e:
logger.error(f"创建职业失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"创建职业失败: {str(e)}")
@router.get("/generate-system", summary="AI生成新职业(增量式,流式)")
async def generate_career_system(
project_id: str,
main_career_count: int = 3,
sub_career_count: int = 6,
enable_mcp: bool = False,
http_request: Request = None,
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
"""
使用AI生成新职业(增量式,基于已有职业补充,支持SSE流式进度显示)
通过Server-Sent Events返回实时进度信息
"""
async def generate() -> AsyncGenerator[str, None]:
try:
# 验证用户权限和项目是否存在
user_id = getattr(http_request.state, 'user_id', None)
project = await verify_project_access(project_id, user_id, db)
yield await SSEResponse.send_progress("开始生成新职业...", 0)
# 获取已有职业列表
yield await SSEResponse.send_progress("分析已有职业...", 5)
existing_careers_result = await db.execute(
select(Career).where(Career.project_id == project_id)
)
existing_careers = existing_careers_result.scalars().all()
# 构建已有职业摘要
existing_main_careers = []
existing_sub_careers = []
for career in existing_careers:
career_summary = f"- {career.name}{career.category or '未分类'}{career.max_stage}阶)"
if career.description:
career_summary += f": {career.description[:50]}"
if career.type == "main":
existing_main_careers.append(career_summary)
else:
existing_sub_careers.append(career_summary)
existing_careers_text = ""
if existing_main_careers:
existing_careers_text += f"\n已有主职业({len(existing_main_careers)}个):\n" + "\n".join(existing_main_careers)
if existing_sub_careers:
existing_careers_text += f"\n\n已有副职业({len(existing_sub_careers)}个):\n" + "\n".join(existing_sub_careers)
if not existing_careers_text:
existing_careers_text = "\n当前还没有任何职业,这是第一次创建职业体系。"
# 构建项目上下文
yield await SSEResponse.send_progress("分析项目世界观...", 15)
project_context = f"""
项目信息:
- 书名:{project.title}
- 类型:{project.genre or '未设定'}
- 主题:{project.theme or '未设定'}
- 时间背景:{project.world_time_period or '未设定'}
- 地理位置:{project.world_location or '未设定'}
- 氛围基调:{project.world_atmosphere or '未设定'}
- 世界规则:{project.world_rules or '未设定'}
"""
user_requirements = f"""
已有职业情况:{existing_careers_text}
生成要求(增量式):
- 本次新增主职业:{main_career_count}
- 本次新增副职业:{sub_career_count}
- ⚠️ 重要:请生成与已有职业**不重复**的新职业,形成互补体系
- 新职业应填补已有职业体系的空缺,丰富职业多样性
- 主职业必须严格符合世界观规则,体现核心能力体系
- 副职业可以更加自由灵活,包含生产、辅助、特殊类型
"""
yield await SSEResponse.send_progress("构建AI提示词...", 20)
# 构建提示词
prompt = f"""{project_context}
{user_requirements}
请为这个小说项目生成新的补充职业(增量式)。要求:
1. **仔细分析已有职业**,避免生成重复或相似的职业
2. **填补职业体系的空缺**,让职业体系更加完善和多样化
3. 如果已有职业较少,可以生成核心基础职业
4. 如果已有职业较多,可以生成特色化、专精化的职业
返回JSON格式,结构如下:
{{
"main_careers": [
{{
"name": "职业名称",
"description": "职业描述",
"category": "职业分类(如:战斗系、法术系等)",
"stages": [
{{"level": 1, "name": "阶段名称", "description": "阶段描述"}},
{{"level": 2, "name": "阶段名称", "description": "阶段描述"}},
...
],
"max_stage": 10,
"requirements": "职业要求",
"special_abilities": "特殊能力",
"worldview_rules": "世界观规则关联",
"attribute_bonuses": {{"strength": "+10%", "intelligence": "+5%"}}
}}
],
"sub_careers": [
{{
"name": "副职业名称",
"description": "职业描述",
"category": "生产系/辅助系/特殊系",
"stages": [...],
"max_stage": 5,
"requirements": "职业要求",
"special_abilities": "特殊能力"
}}
]
}}
注意事项:
1. **避免重复**:生成的职业名称和定位不能与已有职业重复
2. **互补性**:新职业应与已有职业形成互补,丰富职业体系
3. 主职业的阶段设定要详细,体现明确的成长路径
4. 阶段名称要符合世界观特色
5. 副职业可以相对简化,但要有独特性
6. 所有职业都要符合项目的整体世界观设定
7. 只返回纯JSON,不要添加任何解释文字
"""
yield await SSEResponse.send_progress("调用AI生成新职业...", 10)
logger.info(f"🎯 开始为项目 {project_id} 生成新职业(增量式,已有{len(existing_careers)}个职业)")
try:
# 使用流式生成替代非流式
ai_response = ""
chunk_count = 0
last_progress = 10
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
chunk_count += 1
ai_response += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 平滑更新进度(10-90%,AI生成占60%
# 每10个chunk增加约1%的进度,最多到90%
if chunk_count % 10 == 0:
# 计算进度:10% + (chunk_count / 10) * 1%,但不超过90%
current_progress = min(10 + (chunk_count // 10), 90)
if current_progress > last_progress:
last_progress = current_progress
yield await SSEResponse.send_progress(
f"AI生成职业体系中... (已生成 {len(ai_response)} 字符)",
current_progress
)
# 心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
except Exception as ai_error:
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
yield await SSEResponse.send_error(f"AI服务调用失败:{str(ai_error)}")
return
if not ai_response or not ai_response.strip():
yield await SSEResponse.send_error("AI服务返回空响应")
return
yield await SSEResponse.send_progress("解析AI响应...", 91)
# 清洗并解析JSON
try:
cleaned_response = user_ai_service._clean_json_response(ai_response)
career_data = json.loads(cleaned_response)
logger.info(f"✅ 职业体系JSON解析成功")
except json.JSONDecodeError as e:
logger.error(f"❌ 职业体系JSON解析失败: {e}")
logger.error(f" 原始响应预览: {ai_response[:200]}")
yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON{str(e)}")
return
yield await SSEResponse.send_progress("保存主职业到数据库...", 93)
# 保存主职业
main_careers_created = []
for idx, career_info in enumerate(career_data.get("main_careers", [])):
try:
stages_json = json.dumps(career_info.get("stages", []), ensure_ascii=False)
attribute_bonuses = career_info.get("attribute_bonuses")
attribute_bonuses_json = json.dumps(attribute_bonuses, ensure_ascii=False) if attribute_bonuses else None
career = Career(
project_id=project_id,
name=career_info.get("name", f"未命名主职业{idx+1}"),
type="main",
description=career_info.get("description"),
category=career_info.get("category"),
stages=stages_json,
max_stage=career_info.get("max_stage", 10),
requirements=career_info.get("requirements"),
special_abilities=career_info.get("special_abilities"),
worldview_rules=career_info.get("worldview_rules"),
attribute_bonuses=attribute_bonuses_json,
source="ai"
)
db.add(career)
await db.flush()
main_careers_created.append(career.name)
logger.info(f" ✅ 创建主职业:{career.name}")
except Exception as e:
logger.error(f" ❌ 创建主职业失败:{str(e)}")
continue
yield await SSEResponse.send_progress("保存副职业到数据库...", 96)
# 保存副职业
sub_careers_created = []
for idx, career_info in enumerate(career_data.get("sub_careers", [])):
try:
stages_json = json.dumps(career_info.get("stages", []), ensure_ascii=False)
attribute_bonuses = career_info.get("attribute_bonuses")
attribute_bonuses_json = json.dumps(attribute_bonuses, ensure_ascii=False) if attribute_bonuses else None
career = Career(
project_id=project_id,
name=career_info.get("name", f"未命名副职业{idx+1}"),
type="sub",
description=career_info.get("description"),
category=career_info.get("category"),
stages=stages_json,
max_stage=career_info.get("max_stage", 5),
requirements=career_info.get("requirements"),
special_abilities=career_info.get("special_abilities"),
worldview_rules=career_info.get("worldview_rules"),
attribute_bonuses=attribute_bonuses_json,
source="ai"
)
db.add(career)
await db.flush()
sub_careers_created.append(career.name)
logger.info(f" ✅ 创建副职业:{career.name}")
except Exception as e:
logger.error(f" ❌ 创建副职业失败:{str(e)}")
continue
await db.commit()
total_main = len(existing_main_careers) + len(main_careers_created)
total_sub = len(existing_sub_careers) + len(sub_careers_created)
logger.info(f"🎉 新职业生成完成:新增主职业{len(main_careers_created)}个,新增副职业{len(sub_careers_created)}")
logger.info(f" 职业体系总数:主职业{total_main}个,副职业{total_sub}")
yield await SSEResponse.send_progress(f"新职业生成完成!(主职业{total_main}个,副职业{total_sub}个)", 100, "success")
# 发送结果数据
yield await SSEResponse.send_result({
"main_careers_count": len(main_careers_created),
"sub_careers_count": len(sub_careers_created),
"main_careers": main_careers_created,
"sub_careers": sub_careers_created
})
yield await SSEResponse.send_done()
except HTTPException as he:
logger.error(f"HTTP异常: {he.detail}")
yield await SSEResponse.send_error(he.detail, he.status_code)
except Exception as e:
logger.error(f"生成职业体系失败: {str(e)}")
yield await SSEResponse.send_error(f"生成新职业失败: {str(e)}")
return create_sse_response(generate())
@router.put("/{career_id}", response_model=CareerResponse, summary="更新职业")
async def update_career(
career_id: str,
career_update: CareerUpdate,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""更新职业信息"""
result = await db.execute(
select(Career).where(Career.id == career_id)
)
career = result.scalar_one_or_none()
if not career:
raise HTTPException(status_code=404, detail="职业不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(career.project_id, user_id, db)
# 更新字段
update_data = career_update.model_dump(exclude_unset=True)
for field, value in update_data.items():
if field == "stages" and value is not None:
# 转换为JSON字符串
# model_dump() 已经将嵌套模型转换为字典,所以 value 中的元素已经是 dict
stages_list = [
stage if isinstance(stage, dict) else stage.model_dump()
for stage in value
]
setattr(career, field, json.dumps(stages_list, ensure_ascii=False))
elif field == "attribute_bonuses" and value is not None:
# 转换为JSON字符串
setattr(career, field, json.dumps(value, ensure_ascii=False))
else:
setattr(career, field, value)
await db.commit()
await db.refresh(career)
logger.info(f"✅ 更新职业成功:{career.name} (ID: {career_id})")
# 解析JSON返回
stages = json.loads(career.stages) if career.stages else []
attribute_bonuses = json.loads(career.attribute_bonuses) if career.attribute_bonuses else None
return CareerResponse(
id=career.id,
project_id=career.project_id,
name=career.name,
type=career.type,
description=career.description,
category=career.category,
stages=stages,
max_stage=career.max_stage,
requirements=career.requirements,
special_abilities=career.special_abilities,
worldview_rules=career.worldview_rules,
attribute_bonuses=attribute_bonuses,
source=career.source,
created_at=career.created_at,
updated_at=career.updated_at
)
@router.delete("/{career_id}", summary="删除职业")
async def delete_career(
career_id: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""删除职业"""
result = await db.execute(
select(Career).where(Career.id == career_id)
)
career = result.scalar_one_or_none()
if not career:
raise HTTPException(status_code=404, detail="职业不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(career.project_id, user_id, db)
# 检查是否有角色使用该职业
char_career_result = await db.execute(
select(func.count(CharacterCareer.id)).where(CharacterCareer.career_id == career_id)
)
usage_count = char_career_result.scalar_one()
if usage_count > 0:
raise HTTPException(
status_code=400,
detail=f"该职业被{usage_count}个角色使用,无法删除。请先移除角色的职业关联。"
)
await db.delete(career)
await db.commit()
logger.info(f"✅ 删除职业成功:{career.name} (ID: {career_id})")
return {"message": "职业删除成功"}
@router.get("/{career_id}", response_model=CareerResponse, summary="获取职业详情")
async def get_career(
career_id: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""根据ID获取职业详情"""
result = await db.execute(
select(Career).where(Career.id == career_id)
)
career = result.scalar_one_or_none()
if not career:
raise HTTPException(status_code=404, detail="职业不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(career.project_id, user_id, db)
# 解析JSON字段
stages = json.loads(career.stages) if career.stages else []
attribute_bonuses = json.loads(career.attribute_bonuses) if career.attribute_bonuses else None
return CareerResponse(
id=career.id,
project_id=career.project_id,
name=career.name,
type=career.type,
description=career.description,
category=career.category,
stages=stages,
max_stage=career.max_stage,
requirements=career.requirements,
special_abilities=career.special_abilities,
worldview_rules=career.worldview_rules,
attribute_bonuses=attribute_bonuses,
source=career.source,
created_at=career.created_at,
updated_at=career.updated_at
)
# ===== 角色职业关联API =====
@router.get("/character/{character_id}/careers", response_model=CharacterCareerResponse, summary="获取角色的职业信息")
async def get_character_careers(
character_id: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""获取角色的所有职业信息(主职业和副职业)"""
# 验证角色存在
char_result = await db.execute(
select(Character).where(Character.id == character_id)
)
character = char_result.scalar_one_or_none()
if not character:
raise HTTPException(status_code=404, detail="角色不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(character.project_id, user_id, db)
# 获取角色的所有职业关联
result = await db.execute(
select(CharacterCareer, Career)
.join(Career, CharacterCareer.career_id == Career.id)
.where(CharacterCareer.character_id == character_id)
.order_by(CharacterCareer.career_type.desc()) # main排在前
)
career_relations = result.all()
main_career = None
sub_careers = []
for char_career, career in career_relations:
# 解析职业的阶段信息
stages = json.loads(career.stages) if career.stages else []
# 找到当前阶段信息
stage_name = "未知阶段"
stage_description = None
for stage in stages:
if stage.get("level") == char_career.current_stage:
stage_name = stage.get("name", f"{char_career.current_stage}阶段")
stage_description = stage.get("description")
break
career_detail = CharacterCareerDetail(
id=char_career.id,
character_id=char_career.character_id,
career_id=char_career.career_id,
career_name=career.name,
career_type=char_career.career_type,
current_stage=char_career.current_stage,
stage_name=stage_name,
stage_description=stage_description,
stage_progress=char_career.stage_progress,
max_stage=career.max_stage,
started_at=char_career.started_at,
reached_current_stage_at=char_career.reached_current_stage_at,
notes=char_career.notes,
created_at=char_career.created_at,
updated_at=char_career.updated_at
)
if char_career.career_type == "main":
main_career = career_detail
else:
sub_careers.append(career_detail)
return CharacterCareerResponse(
main_career=main_career,
sub_careers=sub_careers
)
@router.post("/character/{character_id}/careers/main", summary="设置角色主职业")
async def set_main_career(
character_id: str,
career_request: SetMainCareerRequest,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""设置或更换角色的主职业"""
# 验证角色存在
char_result = await db.execute(
select(Character).where(Character.id == character_id)
)
character = char_result.scalar_one_or_none()
if not character:
raise HTTPException(status_code=404, detail="角色不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(character.project_id, user_id, db)
# 验证职业存在且为主职业类型
career_result = await db.execute(
select(Career).where(
Career.id == career_request.career_id,
Career.project_id == character.project_id
)
)
career = career_result.scalar_one_or_none()
if not career:
raise HTTPException(status_code=404, detail="职业不存在")
if career.type != "main":
raise HTTPException(status_code=400, detail="该职业不是主职业类型,无法设置为主职业")
# 验证阶段有效性
if career_request.current_stage > career.max_stage:
raise HTTPException(
status_code=400,
detail=f"阶段超出范围,该职业最大阶段为{career.max_stage}"
)
# 检查是否已有主职业
existing_main = await db.execute(
select(CharacterCareer).where(
CharacterCareer.character_id == character_id,
CharacterCareer.career_type == "main"
)
)
current_main = existing_main.scalar_one_or_none()
if current_main:
# 删除旧的主职业
await db.delete(current_main)
logger.info(f" 移除旧主职业关联: {current_main.career_id}")
# 创建新的主职业关联
char_career = CharacterCareer(
character_id=character_id,
career_id=career_request.career_id,
career_type="main",
current_stage=career_request.current_stage,
stage_progress=0,
started_at=career_request.started_at,
reached_current_stage_at=career_request.started_at
)
db.add(char_career)
await db.commit()
logger.info(f"✅ 设置主职业成功:角色{character.name} -> {career.name}(第{career_request.current_stage}阶段)")
return {"message": "主职业设置成功", "career_name": career.name}
@router.post("/character/{character_id}/careers/sub", summary="添加角色副职业")
async def add_sub_career(
character_id: str,
career_request: AddSubCareerRequest,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""为角色添加副职业"""
# 验证角色存在
char_result = await db.execute(
select(Character).where(Character.id == character_id)
)
character = char_result.scalar_one_or_none()
if not character:
raise HTTPException(status_code=404, detail="角色不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(character.project_id, user_id, db)
# 验证职业存在且为副职业类型
career_result = await db.execute(
select(Career).where(
Career.id == career_request.career_id,
Career.project_id == character.project_id
)
)
career = career_result.scalar_one_or_none()
if not career:
raise HTTPException(status_code=404, detail="职业不存在")
if career.type != "sub":
raise HTTPException(status_code=400, detail="该职业不是副职业类型,无法添加为副职业")
# 验证阶段有效性
if career_request.current_stage > career.max_stage:
raise HTTPException(
status_code=400,
detail=f"阶段超出范围,该职业最大阶段为{career.max_stage}"
)
# 检查是否已存在
existing_check = await db.execute(
select(CharacterCareer).where(
CharacterCareer.character_id == character_id,
CharacterCareer.career_id == career_request.career_id
)
)
if existing_check.scalar_one_or_none():
raise HTTPException(status_code=400, detail="该角色已拥有此副职业")
# 检查副职业数量限制(可选,这里设置为最多5个)
sub_count_result = await db.execute(
select(func.count(CharacterCareer.id)).where(
CharacterCareer.character_id == character_id,
CharacterCareer.career_type == "sub"
)
)
sub_count = sub_count_result.scalar_one()
if sub_count >= 5:
raise HTTPException(status_code=400, detail="副职业数量已达上限(最多5个)")
# 创建副职业关联
char_career = CharacterCareer(
character_id=character_id,
career_id=career_request.career_id,
career_type="sub",
current_stage=career_request.current_stage,
stage_progress=0,
started_at=career_request.started_at,
reached_current_stage_at=career_request.started_at
)
db.add(char_career)
await db.commit()
logger.info(f"✅ 添加副职业成功:角色{character.name} -> {career.name}(第{career_request.current_stage}阶段)")
return {"message": "副职业添加成功", "career_name": career.name}
@router.put("/character/{character_id}/careers/{career_id}/stage", summary="更新职业阶段")
async def update_career_stage(
character_id: str,
career_id: str,
stage_request: UpdateCareerStageRequest,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""更新角色在某个职业的阶段"""
# 验证角色职业关联存在
result = await db.execute(
select(CharacterCareer, Career, Character)
.join(Career, CharacterCareer.career_id == Career.id)
.join(Character, CharacterCareer.character_id == Character.id)
.where(
CharacterCareer.character_id == character_id,
CharacterCareer.career_id == career_id
)
)
relation_data = result.one_or_none()
if not relation_data:
raise HTTPException(status_code=404, detail="角色职业关联不存在")
char_career, career, character = relation_data
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(character.project_id, user_id, db)
# 验证新阶段有效性
if stage_request.current_stage > career.max_stage:
raise HTTPException(
status_code=400,
detail=f"阶段超出范围,该职业最大阶段为{career.max_stage}"
)
# 验证阶段递增规则(不能倒退,除非降级)
if stage_request.current_stage < char_career.current_stage:
logger.warning(f"⚠️ 角色{character.name}的职业{career.name}阶段降低:{char_career.current_stage} -> {stage_request.current_stage}")
# 更新阶段信息
char_career.current_stage = stage_request.current_stage
char_career.stage_progress = stage_request.stage_progress
if stage_request.reached_current_stage_at:
char_career.reached_current_stage_at = stage_request.reached_current_stage_at
if stage_request.notes is not None:
char_career.notes = stage_request.notes
await db.commit()
logger.info(f"✅ 更新职业阶段成功:{character.name}{career.name} -> 第{stage_request.current_stage}阶段")
return {
"message": "职业阶段更新成功",
"career_name": career.name,
"new_stage": stage_request.current_stage
}
@router.delete("/character/{character_id}/careers/{career_id}", summary="删除角色副职业")
async def remove_sub_career(
character_id: str,
career_id: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""删除角色的副职业"""
# 验证角色职业关联存在
result = await db.execute(
select(CharacterCareer, Character)
.join(Character, CharacterCareer.character_id == Character.id)
.where(
CharacterCareer.character_id == character_id,
CharacterCareer.career_id == career_id
)
)
relation_data = result.one_or_none()
if not relation_data:
raise HTTPException(status_code=404, detail="角色职业关联不存在")
char_career, character = relation_data
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(character.project_id, user_id, db)
# 不允许删除主职业
if char_career.career_type == "main":
raise HTTPException(status_code=400, detail="无法删除主职业,只能更换")
await db.delete(char_career)
await db.commit()
logger.info(f"✅ 删除副职业成功:角色{character.name}移除职业{career_id}")
return {"message": "副职业删除成功"}
+233
View File
@@ -0,0 +1,233 @@
"""
更新日志API
提供GitHub提交历史的缓存和代理服务
"""
from fastapi import APIRouter, HTTPException, Query
from typing import List, Optional
import httpx
from datetime import datetime, timedelta
from pydantic import BaseModel
import logging
logger = logging.getLogger(__name__)
router = APIRouter()
# GitHub API配置
GITHUB_API_BASE = "https://api.github.com"
REPO_OWNER = "xiamuceer-j"
REPO_NAME = "MuMuAINovel"
# 缓存配置
_cache = {
"data": None,
"timestamp": None,
"ttl": timedelta(hours=1) # 缓存1小时
}
class GitHubAuthor(BaseModel):
"""GitHub作者信息"""
name: str
email: str
date: str
class GitHubCommitInfo(BaseModel):
"""GitHub提交信息"""
author: GitHubAuthor
message: str
class GitHubUser(BaseModel):
"""GitHub用户信息"""
login: str
avatar_url: str
class GitHubCommit(BaseModel):
"""GitHub提交数据"""
sha: str
commit: GitHubCommitInfo
html_url: str
author: Optional[GitHubUser] = None
class ChangelogResponse(BaseModel):
"""更新日志响应"""
commits: List[GitHubCommit]
cached: bool
cache_time: Optional[str] = None
def is_cache_valid() -> bool:
"""检查缓存是否有效"""
if _cache["data"] is None or _cache["timestamp"] is None:
return False
now = datetime.now()
cache_age = now - _cache["timestamp"]
return cache_age < _cache["ttl"]
async def fetch_github_commits(page: int = 1, per_page: int = 30) -> List[dict]:
"""从GitHub API获取提交历史"""
url = f"{GITHUB_API_BASE}/repos/{REPO_OWNER}/{REPO_NAME}/commits"
params = {
"author": REPO_OWNER,
"page": page,
"per_page": per_page
}
headers = {
"Accept": "application/vnd.github.v3+json",
"User-Agent": "MuMuAINovel-App"
}
try:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url, params=params, headers=headers)
response.raise_for_status()
return response.json()
except httpx.HTTPError as e:
logger.error(f"GitHub API请求失败: {str(e)}")
raise HTTPException(
status_code=502,
detail=f"获取GitHub提交历史失败: {str(e)}"
)
@router.get("/changelog", response_model=ChangelogResponse)
async def get_changelog(
page: int = Query(1, ge=1, description="页码"),
per_page: int = Query(30, ge=1, le=100, description="每页数量")
):
"""
获取更新日志
从GitHub获取项目的提交历史,支持缓存以减少API调用
- **page**: 页码,从1开始
- **per_page**: 每页返回的提交数量,最大100
"""
try:
# 只缓存第一页
if page == 1 and is_cache_valid():
logger.info("使用缓存的更新日志")
return ChangelogResponse(
commits=_cache["data"],
cached=True,
cache_time=_cache["timestamp"].isoformat()
)
# 从GitHub获取数据
logger.info(f"从GitHub获取更新日志 (page={page}, per_page={per_page})")
commits_data = await fetch_github_commits(page, per_page)
# 解析数据
commits = []
for commit_data in commits_data:
try:
commit = GitHubCommit(
sha=commit_data["sha"],
commit=GitHubCommitInfo(
author=GitHubAuthor(
name=commit_data["commit"]["author"]["name"],
email=commit_data["commit"]["author"]["email"],
date=commit_data["commit"]["author"]["date"]
),
message=commit_data["commit"]["message"]
),
html_url=commit_data["html_url"],
author=GitHubUser(
login=commit_data["author"]["login"],
avatar_url=commit_data["author"]["avatar_url"]
) if commit_data.get("author") else None
)
commits.append(commit)
except (KeyError, TypeError) as e:
logger.warning(f"解析提交数据失败: {str(e)}")
continue
# 缓存第一页数据
if page == 1:
_cache["data"] = commits
_cache["timestamp"] = datetime.now()
logger.info("已缓存更新日志")
return ChangelogResponse(
commits=commits,
cached=False,
cache_time=None
)
except HTTPException:
raise
except Exception as e:
logger.error(f"获取更新日志时发生错误: {str(e)}")
raise HTTPException(
status_code=500,
detail=f"获取更新日志失败: {str(e)}"
)
@router.post("/changelog/refresh")
async def refresh_changelog():
"""
刷新更新日志缓存
强制从GitHub重新获取最新的提交历史
"""
try:
logger.info("刷新更新日志缓存")
# 清除缓存
_cache["data"] = None
_cache["timestamp"] = None
# 重新获取
commits_data = await fetch_github_commits(1, 30)
# 解析数据
commits = []
for commit_data in commits_data:
try:
commit = GitHubCommit(
sha=commit_data["sha"],
commit=GitHubCommitInfo(
author=GitHubAuthor(
name=commit_data["commit"]["author"]["name"],
email=commit_data["commit"]["author"]["email"],
date=commit_data["commit"]["author"]["date"]
),
message=commit_data["commit"]["message"]
),
html_url=commit_data["html_url"],
author=GitHubUser(
login=commit_data["author"]["login"],
avatar_url=commit_data["author"]["avatar_url"]
) if commit_data.get("author") else None
)
commits.append(commit)
except (KeyError, TypeError) as e:
logger.warning(f"解析提交数据失败: {str(e)}")
continue
# 更新缓存
_cache["data"] = commits
_cache["timestamp"] = datetime.now()
return {
"success": True,
"message": "缓存已刷新",
"commit_count": len(commits),
"cache_time": _cache["timestamp"].isoformat()
}
except Exception as e:
logger.error(f"刷新缓存时发生错误: {str(e)}")
raise HTTPException(
status_code=500,
detail=f"刷新缓存失败: {str(e)}"
)
+2150 -168
View File
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+490
View File
@@ -0,0 +1,490 @@
"""灵感模式API - 通过对话引导创建项目"""
from fastapi import APIRouter, Depends, Request
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Dict, Any
import json
from app.database import get_db
from app.services.ai_service import AIService
from app.api.settings import get_user_ai_service
from app.services.prompt_service import PromptService
from app.logger import get_logger
router = APIRouter(prefix="/inspiration", tags=["灵感模式"])
logger = get_logger(__name__)
# 不同阶段的temperature设置(递减以保持一致性)
TEMPERATURE_SETTINGS = {
"title": 0.8, # 书名阶段可以更有创意
"description": 0.65, # 简介需要贴合书名和原始想法
"theme": 0.55, # 主题需要更加贴合
"genre": 0.45 # 类型应该很明确
}
def validate_options_response(result: Dict[str, Any], step: str, max_retries: int = 3) -> tuple[bool, str]:
"""
校验AI返回的选项格式是否正确
Returns:
(is_valid, error_message)
"""
# 检查必需字段
if "options" not in result:
return False, "缺少options字段"
options = result.get("options", [])
# 检查options是否为数组
if not isinstance(options, list):
return False, "options必须是数组"
# 检查数组长度
if len(options) < 3:
return False, f"选项数量不足,至少需要3个,当前只有{len(options)}"
if len(options) > 10:
return False, f"选项数量过多,最多10个,当前有{len(options)}"
# 检查每个选项是否为字符串且不为空
for i, option in enumerate(options):
if not isinstance(option, str):
return False, f"{i+1}个选项不是字符串类型"
if not option.strip():
return False, f"{i+1}个选项为空"
if len(option) > 500:
return False, f"{i+1}个选项过长(超过500字符)"
# 根据不同步骤进行特定校验
if step == "genre":
# 类型标签应该比较短
for i, option in enumerate(options):
if len(option) > 10:
return False, f"类型标签【{option}】过长,应该在2-10字之间"
return True, ""
@router.post("/generate-options")
async def generate_options(
data: Dict[str, Any],
http_request: Request,
db: AsyncSession = Depends(get_db),
ai_service: AIService = Depends(get_user_ai_service)
) -> Dict[str, Any]:
"""
根据当前收集的信息生成下一步的选项建议(带自动重试)
Request:
{
"step": "title", // title/description/theme/genre
"context": {
"title": "...",
"description": "...",
"theme": "..."
}
}
Response:
{
"prompt": "引导语",
"options": ["选项1", "选项2", ...]
}
"""
max_retries = 3
for attempt in range(max_retries):
try:
step = data.get("step", "title")
context = data.get("context", {})
logger.info(f"灵感模式:生成{step}阶段的选项(第{attempt + 1}次尝试)")
# 获取用户ID
user_id = getattr(http_request.state, 'user_id', None)
# 获取对应的提示词模板(根据step确定模板key)
# 新结构:每个步骤有独立的 SYSTEM 和 USER 模板
template_key_map = {
"title": ("INSPIRATION_TITLE_SYSTEM", "INSPIRATION_TITLE_USER"),
"description": ("INSPIRATION_DESCRIPTION_SYSTEM", "INSPIRATION_DESCRIPTION_USER"),
"theme": ("INSPIRATION_THEME_SYSTEM", "INSPIRATION_THEME_USER"),
"genre": ("INSPIRATION_GENRE_SYSTEM", "INSPIRATION_GENRE_USER")
}
template_keys = template_key_map.get(step)
if not template_keys:
return {
"error": f"不支持的步骤: {step}",
"prompt": "",
"options": []
}
system_key, user_key = template_keys
# 获取自定义提示词模板(分别获取 system 和 user
system_template = await PromptService.get_template(system_key, user_id, db)
user_template = await PromptService.get_template(user_key, user_id, db)
# 准备格式化参数
format_params = {
"initial_idea": context.get("initial_idea", context.get("description", "")),
"title": context.get("title", ""),
"description": context.get("description", ""),
"theme": context.get("theme", "")
}
# 格式化提示词
system_prompt = system_template.format(**format_params)
user_prompt = user_template.format(**format_params)
# 如果是重试,在提示词中强调格式要求
if attempt > 0:
system_prompt += f"\n\n⚠️ 这是第{attempt + 1}次生成,请务必严格按照JSON格式返回,确保options数组包含6个有效选项!"
# 调用AI生成选项
# 关键改进:使用递减的temperature以保持后续阶段与前文的一致性
temperature = TEMPERATURE_SETTINGS.get(step, 0.7)
logger.info(f"调用AI生成{step}选项... (temperature={temperature})")
# 流式生成并累积文本
accumulated_text = ""
async for chunk in ai_service.generate_text_stream(
prompt=user_prompt,
system_prompt=system_prompt,
temperature=temperature
):
accumulated_text += chunk
response = {"content": accumulated_text}
content = accumulated_text
logger.info(f"AI返回内容长度: {len(content)}")
# 解析JSON(使用统一的JSON清洗方法)
try:
# 使用统一的JSON清洗方法
cleaned_content = ai_service._clean_json_response(content)
result = json.loads(cleaned_content)
# 校验返回格式
is_valid, error_msg = validate_options_response(result, step)
if not is_valid:
logger.warning(f"⚠️ 第{attempt + 1}次生成格式校验失败: {error_msg}")
if attempt < max_retries - 1:
logger.info("准备重试...")
continue # 重试
else:
# 最后一次尝试也失败了
return {
"prompt": f"请为【{step}】提供内容:",
"options": ["让AI重新生成", "我自己输入"],
"error": f"AI生成格式错误({error_msg}),已自动重试{max_retries}次,请手动重试或自己输入"
}
logger.info(f"✅ 第{attempt + 1}次成功生成{len(result.get('options', []))}个有效选项")
return result
except json.JSONDecodeError as e:
logger.error(f"{attempt + 1}次JSON解析失败: {e}")
if attempt < max_retries - 1:
logger.info("JSON解析失败,准备重试...")
continue # 重试
else:
# 最后一次尝试也失败了
return {
"prompt": f"请为【{step}】提供内容:",
"options": ["让AI重新生成", "我自己输入"],
"error": f"AI返回格式错误,已自动重试{max_retries}次,请手动重试或自己输入"
}
except Exception as e:
logger.error(f"{attempt + 1}次生成失败: {e}", exc_info=True)
if attempt < max_retries - 1:
logger.info("发生异常,准备重试...")
continue
else:
return {
"error": str(e),
"prompt": "生成失败,请重试",
"options": ["重新生成", "我自己输入"]
}
# 理论上不会到这里
return {
"error": "生成失败",
"prompt": "请重试",
"options": []
}
@router.post("/refine-options")
async def refine_options(
data: Dict[str, Any],
http_request: Request,
db: AsyncSession = Depends(get_db),
ai_service: AIService = Depends(get_user_ai_service)
) -> Dict[str, Any]:
"""
基于用户反馈重新生成选项(支持多轮对话)
Request:
{
"step": "title", // 当前步骤
"context": {
"initial_idea": "...",
"title": "...",
"description": "...",
"theme": "..."
},
"feedback": "我想要更悲剧一些的主题", // 用户反馈
"previous_options": ["选项1", "选项2", ...] // 之前的选项(可选)
}
Response:
{
"prompt": "引导语",
"options": ["新选项1", "新选项2", ...]
}
"""
max_retries = 3
for attempt in range(max_retries):
try:
step = data.get("step", "title")
context = data.get("context", {})
feedback = data.get("feedback", "")
previous_options = data.get("previous_options", [])
logger.info(f"灵感模式:根据反馈重新生成{step}阶段的选项(第{attempt + 1}次尝试)")
logger.info(f"用户反馈: {feedback}")
# 获取用户ID
user_id = getattr(http_request.state, 'user_id', None)
# 获取对应的提示词模板
template_key_map = {
"title": ("INSPIRATION_TITLE_SYSTEM", "INSPIRATION_TITLE_USER"),
"description": ("INSPIRATION_DESCRIPTION_SYSTEM", "INSPIRATION_DESCRIPTION_USER"),
"theme": ("INSPIRATION_THEME_SYSTEM", "INSPIRATION_THEME_USER"),
"genre": ("INSPIRATION_GENRE_SYSTEM", "INSPIRATION_GENRE_USER")
}
template_keys = template_key_map.get(step)
if not template_keys:
return {
"error": f"不支持的步骤: {step}",
"prompt": "",
"options": []
}
system_key, user_key = template_keys
# 获取自定义提示词模板
system_template = await PromptService.get_template(system_key, user_id, db)
user_template = await PromptService.get_template(user_key, user_id, db)
# 准备格式化参数
format_params = {
"initial_idea": context.get("initial_idea", context.get("description", "")),
"title": context.get("title", ""),
"description": context.get("description", ""),
"theme": context.get("theme", "")
}
# 格式化提示词
system_prompt = system_template.format(**format_params)
user_prompt = user_template.format(**format_params)
# 添加反馈信息到提示词
feedback_instruction = f"""
⚠️ 用户对之前的选项不太满意,提供了以下反馈:
{feedback}
之前生成的选项:
{chr(10).join([f"- {opt}" for opt in previous_options]) if previous_options else "(无)"}
请根据用户的反馈调整生成策略,提供更符合用户期望的新选项。
注意:
1. 仔细理解用户的反馈意图
2. 生成的新选项要明显体现用户要求的调整方向
3. 保持与已有上下文的一致性
4. 确保返回6个有效选项
"""
system_prompt += feedback_instruction
# 如果是重试,强调格式要求
if attempt > 0:
system_prompt += f"\n\n⚠️ 这是第{attempt + 1}次生成,请务必严格按照JSON格式返回!"
# 调用AI生成选项
temperature = TEMPERATURE_SETTINGS.get(step, 0.7)
# 反馈生成时使用稍高的temperature以获得更多样化的结果
temperature = min(temperature + 0.1, 0.9)
logger.info(f"调用AI根据反馈生成{step}选项... (temperature={temperature})")
# 流式生成并累积文本
accumulated_text = ""
async for chunk in ai_service.generate_text_stream(
prompt=user_prompt,
system_prompt=system_prompt,
temperature=temperature
):
accumulated_text += chunk
content = accumulated_text
logger.info(f"AI返回内容长度: {len(content)}")
# 解析JSON
try:
cleaned_content = ai_service._clean_json_response(content)
result = json.loads(cleaned_content)
# 校验返回格式
is_valid, error_msg = validate_options_response(result, step)
if not is_valid:
logger.warning(f"⚠️ 第{attempt + 1}次生成格式校验失败: {error_msg}")
if attempt < max_retries - 1:
logger.info("准备重试...")
continue
else:
return {
"prompt": f"请为【{step}】提供内容:",
"options": ["让AI重新生成", "我自己输入"],
"error": f"AI生成格式错误({error_msg}),已自动重试{max_retries}"
}
logger.info(f"✅ 第{attempt + 1}次根据反馈成功生成{len(result.get('options', []))}个有效选项")
return result
except json.JSONDecodeError as e:
logger.error(f"{attempt + 1}次JSON解析失败: {e}")
if attempt < max_retries - 1:
logger.info("JSON解析失败,准备重试...")
continue
else:
return {
"prompt": f"请为【{step}】提供内容:",
"options": ["让AI重新生成", "我自己输入"],
"error": f"AI返回格式错误,已自动重试{max_retries}"
}
except Exception as e:
logger.error(f"{attempt + 1}次根据反馈生成失败: {e}", exc_info=True)
if attempt < max_retries - 1:
logger.info("发生异常,准备重试...")
continue
else:
return {
"error": str(e),
"prompt": "生成失败,请重试",
"options": ["重新生成", "我自己输入"]
}
return {
"error": "生成失败",
"prompt": "请重试",
"options": []
}
@router.post("/quick-generate")
async def quick_generate(
data: Dict[str, Any],
http_request: Request,
db: AsyncSession = Depends(get_db),
ai_service: AIService = Depends(get_user_ai_service)
) -> Dict[str, Any]:
"""
智能补全:根据用户已提供的部分信息,AI自动补全缺失字段
Request:
{
"title": "书名(可选)",
"description": "简介(可选)",
"theme": "主题(可选)",
"genre": ["类型1", "类型2"](可选)
}
Response:
{
"title": "补全的书名",
"description": "补全的简介",
"theme": "补全的主题",
"genre": ["补全的类型"]
}
"""
try:
logger.info("灵感模式:智能补全")
# 获取用户ID
user_id = getattr(http_request.state, 'user_id', None)
# 构建补全提示词
existing_info = []
if data.get("title"):
existing_info.append(f"- 书名:{data['title']}")
if data.get("description"):
existing_info.append(f"- 简介:{data['description']}")
if data.get("theme"):
existing_info.append(f"- 主题:{data['theme']}")
if data.get("genre"):
existing_info.append(f"- 类型:{', '.join(data['genre'])}")
existing_text = "\n".join(existing_info) if existing_info else "暂无信息"
# 获取自定义提示词模板
system_template = await PromptService.get_template("INSPIRATION_QUICK_COMPLETE", user_id, db)
# 格式化提示词
prompts = {
"system": PromptService.format_prompt(system_template, existing=existing_text),
"user": "请补全小说信息"
}
# 调用AI - 流式生成并累积文本
accumulated_text = ""
async for chunk in ai_service.generate_text_stream(
prompt=prompts["user"],
system_prompt=prompts["system"],
temperature=0.7
):
accumulated_text += chunk
response = {"content": accumulated_text}
content = accumulated_text
# 解析JSON(使用统一的JSON清洗方法)
try:
# 使用统一的JSON清洗方法
cleaned_content = ai_service._clean_json_response(content)
result = json.loads(cleaned_content)
# 合并用户已提供的信息(用户输入优先)
final_result = {
"title": data.get("title") or result.get("title", ""),
"description": data.get("description") or result.get("description", ""),
"theme": data.get("theme") or result.get("theme", ""),
"genre": data.get("genre") or result.get("genre", [])
}
logger.info(f"✅ 智能补全成功")
return final_result
except json.JSONDecodeError as e:
logger.error(f"JSON解析失败: {e}")
raise Exception("AI返回格式错误,请重试")
except Exception as e:
logger.error(f"智能补全失败: {e}", exc_info=True)
return {
"error": str(e)
}
+801
View File
@@ -0,0 +1,801 @@
"""MCP插件管理API
重构后使用统一的MCPClientFacade门面来管理所有MCP操作。
"""
from fastapi import APIRouter, HTTPException, Depends, Query, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from typing import List, Optional
from datetime import datetime
from app.database import get_db
from app.models.mcp_plugin import MCPPlugin
from app.schemas.mcp_plugin import (
MCPPluginCreate,
MCPPluginSimpleCreate,
MCPPluginUpdate,
MCPPluginResponse,
MCPToolCall,
MCPTestResult
)
import json
from app.user_manager import User
from app.mcp import mcp_client, MCPPluginConfig, PluginStatus
from app.services.mcp_test_service import mcp_test_service
from app.logger import get_logger
logger = get_logger(__name__)
router = APIRouter(prefix="/mcp/plugins", tags=["MCP插件管理"])
def require_login(request: Request) -> User:
"""依赖:要求用户已登录"""
if not hasattr(request.state, "user") or not request.state.user:
raise HTTPException(status_code=401, detail="需要登录")
return request.state.user
async def _register_plugin_to_facade(plugin: MCPPlugin, user_id: str) -> bool:
"""
将插件注册到统一门面
Args:
plugin: 插件对象
user_id: 用户ID
Returns:
是否注册成功
"""
if plugin.plugin_type in ["http", "streamable_http", "sse"] and plugin.server_url:
return await mcp_client.register(MCPPluginConfig(
user_id=user_id,
plugin_name=plugin.plugin_name,
url=plugin.server_url,
plugin_type=plugin.plugin_type,
headers=plugin.headers,
timeout=plugin.config.get('timeout', 60.0) if plugin.config else 60.0
))
else:
logger.warning(f"暂不支持的插件类型: {plugin.plugin_type}")
return False
@router.get("", response_model=List[MCPPluginResponse])
async def list_plugins(
enabled_only: bool = Query(False, description="只返回启用的插件"),
category: Optional[str] = Query(None, description="按分类筛选"),
user: User = Depends(require_login),
db: AsyncSession = Depends(get_db)
):
"""
获取用户的所有MCP插件
"""
query = select(MCPPlugin).where(MCPPlugin.user_id == user.user_id)
if enabled_only:
query = query.where(MCPPlugin.enabled == True)
if category:
query = query.where(MCPPlugin.category == category)
query = query.order_by(MCPPlugin.sort_order, MCPPlugin.created_at)
result = await db.execute(query)
plugins = result.scalars().all()
logger.info(f"用户 {user.user_id} 查询插件列表,共 {len(plugins)}")
return plugins
@router.post("", response_model=MCPPluginResponse)
async def create_plugin(
data: MCPPluginCreate,
user: User = Depends(require_login),
db: AsyncSession = Depends(get_db)
):
"""
创建新的MCP插件
"""
# 检查插件名是否已存在
result = await db.execute(
select(MCPPlugin).where(
MCPPlugin.user_id == user.user_id,
MCPPlugin.plugin_name == data.plugin_name
)
)
existing = result.scalar_one_or_none()
if existing:
raise HTTPException(status_code=400, detail=f"插件名已存在: {data.plugin_name}")
# 创建插件数据
plugin_data = data.model_dump()
# 如果没有提供display_name,使用plugin_name作为默认值
if not plugin_data.get("display_name"):
plugin_data["display_name"] = plugin_data["plugin_name"]
# 创建插件
plugin = MCPPlugin(
user_id=user.user_id,
**plugin_data
)
db.add(plugin)
await db.commit()
await db.refresh(plugin)
# 如果启用,注册到统一门面
if plugin.enabled:
success = await _register_plugin_to_facade(plugin, user.user_id)
if success:
plugin.status = "active"
else:
plugin.status = "error"
plugin.last_error = "加载失败"
await db.commit()
await db.refresh(plugin)
logger.info(f"用户 {user.user_id} 创建插件: {plugin.plugin_name}")
return plugin
@router.post("/simple", response_model=MCPPluginResponse)
async def create_plugin_simple(
data: MCPPluginSimpleCreate,
user: User = Depends(require_login),
db: AsyncSession = Depends(get_db)
):
"""
通过标准MCP配置JSON创建或更新插件(简化版)
接受格式:
{
"config_json": '{"mcpServers": {"exa": {"type": "http", "url": "...", "headers": {}}}}',
"category": "search"
}
自动从mcpServers中提取插件名称(取第一个键)
如果插件已存在,则更新;否则创建新插件
"""
try:
# 解析配置JSON
config = json.loads(data.config_json)
# 验证格式
if "mcpServers" not in config:
raise HTTPException(status_code=400, detail="配置JSON必须包含mcpServers字段")
servers = config["mcpServers"]
if not servers or len(servers) == 0:
raise HTTPException(status_code=400, detail="mcpServers不能为空")
# 自动提取第一个插件名称
plugin_name = list(servers.keys())[0]
server_config = servers[plugin_name]
logger.info(f"从配置中提取插件名称: {plugin_name}")
# 提取配置
server_type = server_config.get("type", "http")
if server_type not in ["http", "stdio", "streamable_http", "sse"]:
raise HTTPException(status_code=400, detail=f"不支持的服务器类型: {server_type}")
# 检查插件名是否已存在
result = await db.execute(
select(MCPPlugin).where(
MCPPlugin.user_id == user.user_id,
MCPPlugin.plugin_name == plugin_name
)
)
existing = result.scalar_one_or_none()
# 构建插件数据
plugin_data = {
"plugin_name": plugin_name,
"display_name": plugin_name,
"plugin_type": server_type,
"enabled": data.enabled,
"category": data.category,
"sort_order": 0
}
if server_type in ["http", "streamable_http", "sse"]:
plugin_data["server_url"] = server_config.get("url")
plugin_data["headers"] = server_config.get("headers", {})
if not plugin_data["server_url"]:
raise HTTPException(status_code=400, detail=f"{server_type}类型插件必须提供url字段")
elif server_type == "stdio":
plugin_data["command"] = server_config.get("command")
plugin_data["args"] = server_config.get("args", [])
plugin_data["env"] = server_config.get("env", {})
if not plugin_data["command"]:
raise HTTPException(status_code=400, detail="Stdio类型插件必须提供command字段")
if existing:
# 更新现有插件
logger.info(f"插件 {plugin_name} 已存在,执行更新操作")
# 保存旧状态
old_enabled = existing.enabled
old_plugin_name = existing.plugin_name
# 更新字段
for key, value in plugin_data.items():
setattr(existing, key, value)
plugin = existing
await db.commit()
await db.refresh(plugin)
# 数据库完成后进行MCP操作
if old_enabled:
try:
await mcp_client.unregister(user.user_id, old_plugin_name)
except Exception as e:
logger.warning(f"注销旧插件出错: {e}")
if plugin.enabled:
try:
success = await _register_plugin_to_facade(plugin, user.user_id)
plugin.status = "active" if success else "error"
plugin.last_error = None if success else "加载失败"
await db.commit()
except Exception as e:
logger.error(f"注册插件失败: {e}")
plugin.status = "error"
plugin.last_error = str(e)
await db.commit()
logger.info(f"用户 {user.user_id} 更新插件: {plugin_name}")
else:
# 创建新插件
plugin = MCPPlugin(
user_id=user.user_id,
**plugin_data
)
db.add(plugin)
await db.commit()
await db.refresh(plugin)
# 数据库完成后进行MCP操作
if plugin.enabled:
try:
success = await _register_plugin_to_facade(plugin, user.user_id)
plugin.status = "active" if success else "error"
plugin.last_error = None if success else "加载失败"
await db.commit()
except Exception as e:
logger.error(f"注册插件失败: {e}")
plugin.status = "error"
plugin.last_error = str(e)
await db.commit()
logger.info(f"用户 {user.user_id} 通过简化配置创建插件: {plugin_name}")
return plugin
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"配置JSON格式错误: {str(e)}")
except HTTPException:
raise
except Exception as e:
logger.error(f"创建插件失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"创建插件失败: {str(e)}")
@router.get("/{plugin_id}", response_model=MCPPluginResponse)
async def get_plugin(
plugin_id: str,
user: User = Depends(require_login),
db: AsyncSession = Depends(get_db)
):
"""
获取插件详情
"""
result = await db.execute(
select(MCPPlugin).where(
MCPPlugin.id == plugin_id,
MCPPlugin.user_id == user.user_id
)
)
plugin = result.scalar_one_or_none()
if not plugin:
raise HTTPException(status_code=404, detail="插件不存在")
return plugin
@router.put("/{plugin_id}", response_model=MCPPluginResponse)
async def update_plugin(
plugin_id: str,
data: MCPPluginUpdate,
user: User = Depends(require_login),
db: AsyncSession = Depends(get_db)
):
"""
更新插件配置
"""
result = await db.execute(
select(MCPPlugin).where(
MCPPlugin.id == plugin_id,
MCPPlugin.user_id == user.user_id
)
)
plugin = result.scalar_one_or_none()
if not plugin:
raise HTTPException(status_code=404, detail="插件不存在")
# 更新字段
update_data = data.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(plugin, key, value)
await db.commit()
await db.refresh(plugin)
# 如果插件已启用,重新注册
if plugin.enabled:
await mcp_client.unregister(user.user_id, plugin.plugin_name)
await _register_plugin_to_facade(plugin, user.user_id)
logger.info(f"用户 {user.user_id} 更新插件: {plugin.plugin_name}")
return plugin
@router.delete("/{plugin_id}")
async def delete_plugin(
plugin_id: str,
user: User = Depends(require_login),
db: AsyncSession = Depends(get_db)
):
"""
删除插件
"""
result = await db.execute(
select(MCPPlugin).where(
MCPPlugin.id == plugin_id,
MCPPlugin.user_id == user.user_id
)
)
plugin = result.scalar_one_or_none()
if not plugin:
raise HTTPException(status_code=404, detail="插件不存在")
# 从统一门面注销
await mcp_client.unregister(user.user_id, plugin.plugin_name)
# 删除数据库记录
await db.delete(plugin)
await db.commit()
logger.info(f"用户 {user.user_id} 删除插件: {plugin.plugin_name}")
return {"message": "插件已删除", "plugin_name": plugin.plugin_name}
@router.post("/{plugin_id}/toggle", response_model=MCPPluginResponse)
async def toggle_plugin(
plugin_id: str,
enabled: bool = Query(..., description="启用或禁用"),
user: User = Depends(require_login),
db: AsyncSession = Depends(get_db)
):
"""
启用或禁用插件
"""
result = await db.execute(
select(MCPPlugin).where(
MCPPlugin.id == plugin_id,
MCPPlugin.user_id == user.user_id
)
)
plugin = result.scalar_one_or_none()
if not plugin:
raise HTTPException(status_code=404, detail="插件不存在")
# 保存插件信息用于后续MCP操作
plugin_name = plugin.plugin_name
plugin_type = plugin.plugin_type
server_url = plugin.server_url
headers = plugin.headers
config = plugin.config
# 先更新数据库状态
plugin.enabled = enabled
if not enabled:
plugin.status = "inactive"
await db.commit()
await db.refresh(plugin)
# 数据库操作完成后,再进行MCP操作
if enabled:
# 启用:注册到统一门面
try:
if plugin_type in ["http", "streamable_http", "sse"] and server_url:
success = await mcp_client.register(MCPPluginConfig(
user_id=user.user_id,
plugin_name=plugin_name,
url=server_url,
plugin_type=plugin_type,
headers=headers,
timeout=config.get('timeout', 60.0) if config else 60.0
))
else:
success = False
# 更新状态
plugin.status = "active" if success else "error"
plugin.last_error = None if success else "加载失败"
await db.commit()
await db.refresh(plugin)
except Exception as e:
logger.error(f"注册插件失败: {plugin_name}, 错误: {e}")
plugin.status = "error"
plugin.last_error = str(e)
await db.commit()
await db.refresh(plugin)
else:
# 禁用:从统一门面注销(不影响数据库状态)
try:
await mcp_client.unregister(user.user_id, plugin_name)
except Exception as e:
logger.warning(f"注销插件时出错(可忽略): {plugin_name}, 错误: {e}")
action = "启用" if enabled else "禁用"
logger.info(f"用户 {user.user_id} {action}插件: {plugin_name}")
return plugin
@router.post("/{plugin_id}/test", response_model=MCPTestResult)
async def test_plugin(
plugin_id: str,
user: User = Depends(require_login),
db: AsyncSession = Depends(get_db)
):
"""
测试插件连接并调用工具验证功能
使用MCPTestService进行测试
"""
result = await db.execute(
select(MCPPlugin).where(
MCPPlugin.id == plugin_id,
MCPPlugin.user_id == user.user_id
)
)
plugin = result.scalar_one_or_none()
if not plugin:
raise HTTPException(status_code=404, detail="插件不存在")
if not plugin.enabled:
return MCPTestResult(
success=False,
message="插件未启用",
error="请先启用插件",
suggestions=["点击开关按钮启用插件"]
)
# 使用测试服务
try:
test_result = await mcp_test_service.test_plugin_with_ai(plugin, user, db)
# 更新插件状态
if test_result.success:
plugin.status = "active"
plugin.last_error = None
else:
plugin.status = "error"
plugin.last_error = test_result.error
plugin.last_test_at = datetime.now()
await db.commit()
return test_result
except Exception as e:
logger.error(f"测试插件失败: {plugin.plugin_name}, 错误: {e}")
plugin.status = "error"
plugin.last_error = str(e)
plugin.last_test_at = datetime.now()
await db.commit()
raise HTTPException(status_code=500, detail=f"测试失败: {str(e)}")
async def _ensure_plugin_registered(
plugin: MCPPlugin,
user_id: str
) -> bool:
"""
确保插件已注册到统一门面
Args:
plugin: 插件对象
user_id: 用户ID
Returns:
是否成功
Raises:
HTTPException: 注册失败
"""
try:
# 使用ensure_registered方法,它会检查是否已注册
if plugin.plugin_type in ["http", "streamable_http", "sse"] and plugin.server_url:
return await mcp_client.ensure_registered(
user_id=user_id,
plugin_name=plugin.plugin_name,
url=plugin.server_url,
plugin_type=plugin.plugin_type,
headers=plugin.headers
)
return False
except ValueError as e:
logger.info(f"插件 {plugin.plugin_name} 未注册,自动注册中...")
success = await _register_plugin_to_facade(plugin, user_id)
if not success:
raise HTTPException(
status_code=500,
detail=f"插件注册失败: {plugin.plugin_name}"
)
return True
@router.get("/{plugin_id}/status")
async def get_plugin_status(
plugin_id: str,
user: User = Depends(require_login),
db: AsyncSession = Depends(get_db)
):
"""获取插件的实时状态(包括内存中的会话状态)"""
result = await db.execute(
select(MCPPlugin).where(
MCPPlugin.id == plugin_id,
MCPPlugin.user_id == user.user_id
)
)
plugin = result.scalar_one_or_none()
if not plugin:
raise HTTPException(status_code=404, detail="插件不存在")
session_stats = mcp_client.get_session_stats()
session_key = f"{user.user_id}:{plugin.plugin_name}"
session_info = next((s for s in session_stats.get("sessions", []) if s["key"] == session_key), None)
return {
"plugin_id": plugin_id,
"plugin_name": plugin.plugin_name,
"db_status": plugin.status,
"session_status": session_info["status"] if session_info else None,
"is_registered": session_info is not None,
"error_rate": session_info["error_rate"] if session_info else 0,
"in_sync": (plugin.status == session_info["status"]) if session_info else (plugin.status == "inactive"),
"timestamp": datetime.now().isoformat()
}
@router.get("/metrics")
async def get_metrics(
tool_name: Optional[str] = Query(None, description="工具名称(可选,获取特定工具的指标)"),
user: User = Depends(require_login)
):
"""
获取MCP工具调用指标
Query参数:
- tool_name: 可选,指定工具名称获取特定工具的指标
Returns:
工具调用指标字典,包含:
- total_calls: 总调用次数
- success_calls: 成功调用次数
- failed_calls: 失败调用次数
- success_rate: 成功率
- avg_duration_ms: 平均耗时(毫秒)
- last_call_time: 最后调用时间
"""
# 使用统一门面获取指标
metrics = mcp_client.get_metrics(tool_name)
return {
"metrics": metrics,
"tool_name": tool_name,
"timestamp": datetime.now().isoformat()
}
@router.get("/cache/stats")
async def get_cache_stats(
user: User = Depends(require_login)
):
"""
获取工具缓存统计信息
Returns:
缓存统计信息,包含:
- total_entries: 缓存条目总数
- total_hits: 缓存总命中次数
- cache_ttl_minutes: 缓存TTL(分钟)
- entries: 各缓存条目详情
"""
# 使用统一门面获取缓存统计
stats = mcp_client.get_cache_stats()
return {
"cache_stats": stats,
"timestamp": datetime.now().isoformat()
}
@router.get("/sessions/stats")
async def get_session_stats(
user: User = Depends(require_login)
):
"""
获取MCP会话统计信息
Returns:
会话统计信息,包含:
- total_sessions: 会话总数
- sessions: 各会话详情
"""
# 使用统一门面获取会话统计
stats = mcp_client.get_session_stats()
return {
"session_stats": stats,
"timestamp": datetime.now().isoformat()
}
@router.post("/cache/clear")
async def clear_cache(
user_id: Optional[str] = Query(None, description="用户ID(可选)"),
plugin_name: Optional[str] = Query(None, description="插件名称(可选)"),
user: User = Depends(require_login)
):
"""
清理工具缓存
Query参数:
- user_id: 可选,清理特定用户的缓存
- plugin_name: 可选,清理特定插件的缓存
说明:
- 不提供任何参数:清理所有缓存
- 只提供user_id:清理该用户的所有缓存
- 提供user_id和plugin_name:清理特定插件的缓存
"""
# 非管理员只能清理自己的缓存
if user_id and user_id != user.user_id:
raise HTTPException(status_code=403, detail="无权清理其他用户的缓存")
# 如果没有指定user_id,使用当前用户
target_user_id = user_id or user.user_id
# 使用统一门面清理缓存
mcp_client.clear_cache(target_user_id, plugin_name)
message = "已清理"
if plugin_name:
message += f"插件 {plugin_name} 的缓存"
elif target_user_id:
message += f"用户 {target_user_id} 的所有缓存"
else:
message += "所有缓存"
logger.info(f"用户 {user.user_id} {message}")
return {
"success": True,
"message": message,
"timestamp": datetime.now().isoformat()
}
@router.get("/{plugin_id}/tools")
async def get_plugin_tools(
plugin_id: str,
user: User = Depends(require_login),
db: AsyncSession = Depends(get_db)
):
"""
获取插件提供的工具列表
"""
result = await db.execute(
select(MCPPlugin).where(
MCPPlugin.id == plugin_id,
MCPPlugin.user_id == user.user_id
)
)
plugin = result.scalar_one_or_none()
if not plugin:
raise HTTPException(status_code=404, detail="插件不存在")
if not plugin.enabled:
raise HTTPException(status_code=400, detail="插件未启用")
try:
# 确保插件已注册
await _ensure_plugin_registered(plugin, user.user_id)
# 使用统一门面获取工具列表
tools = await mcp_client.get_tools(user.user_id, plugin.plugin_name)
# 更新数据库中的工具缓存
plugin.tools = tools
await db.commit()
return {
"plugin_name": plugin.plugin_name,
"tools": tools,
"count": len(tools)
}
except HTTPException:
raise
except Exception as e:
logger.error(f"获取工具列表失败: {plugin.plugin_name}, 错误: {e}")
raise HTTPException(status_code=500, detail=f"获取工具列表失败: {str(e)}")
@router.post("/call")
async def call_mcp_tool(
data: MCPToolCall,
user: User = Depends(require_login),
db: AsyncSession = Depends(get_db)
):
"""
调用MCP工具
"""
# 获取插件
result = await db.execute(
select(MCPPlugin).where(
MCPPlugin.id == data.plugin_id,
MCPPlugin.user_id == user.user_id
)
)
plugin = result.scalar_one_or_none()
if not plugin:
raise HTTPException(status_code=404, detail="插件不存在")
if not plugin.enabled:
raise HTTPException(status_code=400, detail="插件未启用")
try:
# 确保插件已注册
await _ensure_plugin_registered(plugin, user.user_id)
# 使用统一门面调用工具
tool_result = await mcp_client.call_tool(
user_id=user.user_id,
plugin_name=plugin.plugin_name,
tool_name=data.tool_name,
arguments=data.arguments
)
return {
"success": True,
"plugin_name": plugin.plugin_name,
"tool_name": data.tool_name,
"result": tool_result
}
except HTTPException:
raise
except Exception as e:
logger.error(f"调用工具失败: {plugin.plugin_name}.{data.tool_name}, 错误: {e}")
raise HTTPException(status_code=500, detail=f"工具调用失败: {str(e)}")
+55 -8
View File
@@ -6,6 +6,7 @@ from typing import List, Optional
from app.database import get_db
from app.models.memory import StoryMemory, PlotAnalysis
from app.models.chapter import Chapter
from app.models.project import Project
from app.services.memory_service import memory_service
from app.services.plot_analyzer import get_plot_analyzer
from app.services.ai_service import create_user_ai_service
@@ -17,6 +18,26 @@ logger = get_logger(__name__)
router = APIRouter(prefix="/api/memories", tags=["memories"])
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
"""验证用户是否有权访问指定项目"""
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
result = await db.execute(
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
return project
@router.post("/projects/{project_id}/analyze-chapter/{chapter_id}")
async def analyze_chapter(
project_id: str,
@@ -30,7 +51,10 @@ async def analyze_chapter(
对指定章节进行剧情分析,提取钩子、伏笔、情节点等,并存入记忆系统
"""
try:
user_id = request.state.user_id
user_id = getattr(request.state, 'user_id', None)
# 验证用户权限
await verify_project_access(project_id, user_id, db)
# 获取章节内容
result = await db.execute(
@@ -192,7 +216,10 @@ async def get_project_memories(
):
"""获取项目的记忆列表"""
try:
user_id = request.state.user_id
user_id = getattr(request.state, 'user_id', None)
# 验证用户权限
await verify_project_access(project_id, user_id, db)
# 构建查询
query = select(StoryMemory).where(StoryMemory.project_id == project_id)
@@ -222,10 +249,16 @@ async def get_project_memories(
async def get_chapter_analysis(
project_id: str,
chapter_id: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""获取章节的剧情分析"""
try:
user_id = getattr(request.state, 'user_id', None)
# 验证用户权限
await verify_project_access(project_id, user_id, db)
result = await db.execute(
select(PlotAnalysis).where(
and_(
@@ -258,11 +291,15 @@ async def search_memories(
query: str,
memory_types: Optional[List[str]] = None,
limit: int = 10,
min_importance: float = 0.0
min_importance: float = 0.0,
db: AsyncSession = Depends(get_db)
):
"""语义搜索项目记忆"""
try:
user_id = request.state.user_id
user_id = getattr(request.state, 'user_id', None)
# 验证用户权限
await verify_project_access(project_id, user_id, db)
memories = await memory_service.search_memories(
user_id=user_id,
@@ -294,7 +331,10 @@ async def get_unresolved_foreshadows(
):
"""获取未完结的伏笔"""
try:
user_id = request.state.user_id
user_id = getattr(request.state, 'user_id', None)
# 验证用户权限
await verify_project_access(project_id, user_id, db)
# 从向量库搜索
foreshadows = await memory_service.find_unresolved_foreshadows(
@@ -317,11 +357,15 @@ async def get_unresolved_foreshadows(
@router.get("/projects/{project_id}/stats")
async def get_memory_stats(
project_id: str,
request: Request
request: Request,
db: AsyncSession = Depends(get_db)
):
"""获取记忆统计信息"""
try:
user_id = request.state.user_id
user_id = getattr(request.state, 'user_id', None)
# 验证用户权限
await verify_project_access(project_id, user_id, db)
stats = await memory_service.get_memory_stats(
user_id=user_id,
@@ -347,7 +391,10 @@ async def delete_chapter_memories(
):
"""删除章节的所有记忆"""
try:
user_id = request.state.user_id
user_id = getattr(request.state, 'user_id', None)
# 验证用户权限
await verify_project_access(project_id, user_id, db)
# 从数据库删除
result = await db.execute(
+306 -6
View File
@@ -1,12 +1,17 @@
"""组织管理API"""
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_
from typing import List
from typing import List, Optional, AsyncGenerator
from pydantic import BaseModel, Field
import json
from app.database import get_db
from app.utils.sse_response import SSEResponse, create_sse_response
from app.models.relationship import Organization, OrganizationMember
from app.models.character import Character
from app.models.project import Project
from app.models.generation_history import GenerationHistory
from app.schemas.relationship import (
OrganizationCreate,
OrganizationUpdate,
@@ -17,17 +22,56 @@ from app.schemas.relationship import (
OrganizationMemberResponse,
OrganizationMemberDetailResponse
)
from app.schemas.character import CharacterResponse
from app.services.ai_service import AIService
from app.services.prompt_service import prompt_service, PromptService
from app.logger import get_logger
from app.api.settings import get_user_ai_service
router = APIRouter(prefix="/organizations", tags=["组织管理"])
logger = get_logger(__name__)
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
"""验证用户是否有权访问指定项目"""
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
result = await db.execute(
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
return project
class OrganizationGenerateRequest(BaseModel):
"""AI生成组织的请求模型"""
project_id: str = Field(..., description="项目ID")
name: Optional[str] = Field(None, description="组织名称")
organization_type: Optional[str] = Field(None, description="组织类型")
background: Optional[str] = Field(None, description="组织背景")
requirements: Optional[str] = Field(None, description="特殊要求")
enable_mcp: bool = Field(True, description="是否启用MCP工具增强(搜索组织架构参考)")
@router.get("/project/{project_id}", response_model=List[OrganizationDetailResponse], summary="获取项目的所有组织")
async def get_project_organizations(
project_id: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(project_id, user_id, db)
"""
获取项目中的所有组织及其详情
@@ -67,6 +111,7 @@ async def get_project_organizations(
@router.get("/{org_id}", response_model=OrganizationResponse, summary="获取组织详情")
async def get_organization(
org_id: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""获取组织的详细信息"""
@@ -78,12 +123,17 @@ async def get_organization(
if not org:
raise HTTPException(status_code=404, detail="组织不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(org.project_id, user_id, db)
return org
@router.post("/", response_model=OrganizationResponse, summary="创建组织")
@router.post("", response_model=OrganizationResponse, summary="创建组织")
async def create_organization(
organization: OrganizationCreate,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
@@ -92,6 +142,10 @@ async def create_organization(
- 需要关联到一个已存在的角色记录(is_organization=True
- 可以设置父组织、势力等级等属性
"""
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(organization.project_id, user_id, db)
# 验证角色是否存在且是组织
char_result = await db.execute(
select(Character).where(Character.id == organization.character_id)
@@ -124,6 +178,7 @@ async def create_organization(
async def update_organization(
org_id: str,
organization: OrganizationUpdate,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""更新组织的属性"""
@@ -135,7 +190,11 @@ async def update_organization(
if not db_org:
raise HTTPException(status_code=404, detail="组织不存在")
# 更新字段
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(db_org.project_id, user_id, db)
# 更新 Organization 表字段
update_data = organization.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(db_org, field, value)
@@ -150,6 +209,7 @@ async def update_organization(
@router.delete("/{org_id}", summary="删除组织")
async def delete_organization(
org_id: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""删除组织(会级联删除所有成员关系)"""
@@ -161,6 +221,10 @@ async def delete_organization(
if not db_org:
raise HTTPException(status_code=404, detail="组织不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(db_org.project_id, user_id, db)
await db.delete(db_org)
await db.commit()
@@ -173,6 +237,7 @@ async def delete_organization(
@router.get("/{org_id}/members", response_model=List[OrganizationMemberDetailResponse], summary="获取组织成员")
async def get_organization_members(
org_id: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
@@ -184,9 +249,14 @@ async def get_organization_members(
org_result = await db.execute(
select(Organization).where(Organization.id == org_id)
)
if not org_result.scalar_one_or_none():
org = org_result.scalar_one_or_none()
if not org:
raise HTTPException(status_code=404, detail="组织不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(org.project_id, user_id, db)
# 获取成员列表
result = await db.execute(
select(OrganizationMember)
@@ -226,6 +296,7 @@ async def get_organization_members(
async def add_organization_member(
org_id: str,
member: OrganizationMemberCreate,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
@@ -242,6 +313,10 @@ async def add_organization_member(
if not org:
raise HTTPException(status_code=404, detail="组织不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(org.project_id, user_id, db)
# 验证角色存在
char_result = await db.execute(
select(Character).where(Character.id == member.character_id)
@@ -286,6 +361,7 @@ async def add_organization_member(
async def update_organization_member(
member_id: str,
member: OrganizationMemberUpdate,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""更新组织成员的职位、忠诚度等信息"""
@@ -297,6 +373,14 @@ async def update_organization_member(
if not db_member:
raise HTTPException(status_code=404, detail="成员记录不存在")
# 通过成员所属的组织验证用户权限
org_result = await db.execute(
select(Organization).where(Organization.id == db_member.organization_id)
)
org = org_result.scalar_one()
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(org.project_id, user_id, db)
# 更新字段
update_data = member.model_dump(exclude_unset=True)
for field, value in update_data.items():
@@ -312,6 +396,7 @@ async def update_organization_member(
@router.delete("/members/{member_id}", summary="移除组织成员")
async def remove_organization_member(
member_id: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
@@ -332,10 +417,225 @@ async def remove_organization_member(
select(Organization).where(Organization.id == db_member.organization_id)
)
org = org_result.scalar_one()
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(org.project_id, user_id, db)
org.member_count = max(0, org.member_count - 1)
await db.delete(db_member)
await db.commit()
logger.info(f"移除成员成功:{member_id}")
return {"message": "成员移除成功", "id": member_id}
return {"message": "成员移除成功", "id": member_id}
@router.post("/generate-stream", summary="AI生成组织(流式)")
async def generate_organization_stream(
gen_request: OrganizationGenerateRequest,
http_request: Request,
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
"""
使用AI生成组织设定(支持SSE流式进度显示)
通过Server-Sent Events返回实时进度信息
"""
async def generate() -> AsyncGenerator[str, None]:
try:
# 验证用户权限和项目是否存在
user_id = getattr(http_request.state, 'user_id', None)
project = await verify_project_access(gen_request.project_id, user_id, db)
yield await SSEResponse.send_progress("开始生成组织...", 0)
# 获取已存在的角色和组织列表
yield await SSEResponse.send_progress("获取项目上下文...", 10)
existing_chars_result = await db.execute(
select(Character)
.where(Character.project_id == gen_request.project_id)
.order_by(Character.created_at.desc())
)
existing_characters = existing_chars_result.scalars().all()
# 构建现有角色和组织信息摘要
existing_info = ""
character_list = []
organization_list = []
if existing_characters:
for c in existing_characters[:10]:
if c.is_organization:
organization_list.append(f"- {c.name} [{c.organization_type or '组织'}]")
else:
character_list.append(f"- {c.name}{c.role_type or '未知'}")
if character_list:
existing_info += "\n已有角色:\n" + "\n".join(character_list)
if organization_list:
existing_info += "\n\n已有组织:\n" + "\n".join(organization_list)
# 构建项目上下文
project_context = f"""
项目信息:
- 书名:{project.title}
- 主题:{project.theme or '未设定'}
- 类型:{project.genre or '未设定'}
- 时间背景:{project.world_time_period or '未设定'}
- 地理位置:{project.world_location or '未设定'}
- 氛围基调:{project.world_atmosphere or '未设定'}
- 世界规则:{project.world_rules or '未设定'}
{existing_info}
"""
user_input = f"""
用户要求:
- 组织名称:{gen_request.name or '请AI生成'}
- 组织类型:{gen_request.organization_type or '请AI根据世界观决定'}
- 背景设定:{gen_request.background or '无特殊要求'}
- 其他要求:{gen_request.requirements or ''}
"""
yield await SSEResponse.send_progress("构建AI提示词...", 5)
# 获取自定义提示词模板
template = await PromptService.get_template("SINGLE_ORGANIZATION_GENERATION", user_id, db)
# 格式化提示词
prompt = PromptService.format_prompt(
template,
project_context=project_context,
user_input=user_input
)
yield await SSEResponse.send_progress("调用AI服务生成组织...", 10)
logger.info(f"🎯 开始为项目 {gen_request.project_id} 生成组织(SSE流式)")
try:
# 使用流式生成替代非流式
ai_content = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
chunk_count += 1
ai_content += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新字数(5-95%,AI生成占90%)
if chunk_count % 5 == 0:
progress = min(10 + (chunk_count // 5), 95)
yield await SSEResponse.send_progress(
f"AI生成组织中... ({len(ai_content)}字符)",
progress
)
# 心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
except Exception as ai_error:
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
yield await SSEResponse.send_error(f"AI服务调用失败:{str(ai_error)}")
return
if not ai_content or not ai_content.strip():
yield await SSEResponse.send_error("AI服务返回空响应")
return
yield await SSEResponse.send_progress("解析AI响应...", 90)
# ✅ 使用统一的 JSON 清洗方法
try:
cleaned_response = user_ai_service._clean_json_response(ai_content)
organization_data = json.loads(cleaned_response)
logger.info(f"✅ 组织JSON解析成功")
except json.JSONDecodeError as e:
logger.error(f"❌ 组织JSON解析失败: {e}")
logger.error(f" 原始响应预览: {ai_content[:200]}")
yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON{str(e)}")
return
yield await SSEResponse.send_progress("创建组织记录...", 95)
# 创建角色记录(组织也是角色的一种)
character = Character(
project_id=gen_request.project_id,
name=organization_data.get("name", gen_request.name or "未命名组织"),
is_organization=True,
role_type="supporting",
personality=organization_data.get("personality", ""),
background=organization_data.get("background", ""),
appearance=organization_data.get("appearance", ""),
organization_type=organization_data.get("organization_type"),
organization_purpose=organization_data.get("organization_purpose"),
organization_members=json.dumps(
organization_data.get("organization_members", []),
ensure_ascii=False
),
traits=json.dumps(
organization_data.get("traits", []),
ensure_ascii=False
)
)
db.add(character)
await db.flush()
logger.info(f"✅ 组织角色创建成功:{character.name} (ID: {character.id})")
yield await SSEResponse.send_progress("创建组织详情...", 98)
# 自动创建Organization详情记录
organization = Organization(
character_id=character.id,
project_id=gen_request.project_id,
member_count=0,
power_level=organization_data.get("power_level", 50),
location=organization_data.get("location"),
motto=organization_data.get("motto"),
color=organization_data.get("color")
)
db.add(organization)
await db.flush()
logger.info(f"✅ 组织详情创建成功:{character.name} (Org ID: {organization.id})")
yield await SSEResponse.send_progress("保存生成历史...", 99)
# 记录生成历史
history = GenerationHistory(
project_id=gen_request.project_id,
prompt=prompt,
generated_content=ai_content,
model=user_ai_service.default_model
)
db.add(history)
await db.commit()
await db.refresh(character)
logger.info(f"🎉 成功生成组织: {character.name}")
yield await SSEResponse.send_progress("组织生成完成!", 100, "success")
# 发送结果数据
yield await SSEResponse.send_result({
"character": {
"id": character.id,
"name": character.name,
"organization_type": character.organization_type,
"is_organization": character.is_organization
}
})
yield await SSEResponse.send_done()
except HTTPException as he:
logger.error(f"HTTP异常: {he.detail}")
yield await SSEResponse.send_error(he.detail, he.status_code)
except Exception as e:
logger.error(f"生成组织失败: {str(e)}")
yield await SSEResponse.send_error(f"生成组织失败: {str(e)}")
return create_sse_response(generate())
+2543 -385
View File
File diff suppressed because it is too large Load Diff
+20 -6
View File
@@ -1,12 +1,12 @@
"""AI去味API - 核心特色功能"""
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.models.generation_history import GenerationHistory
from app.schemas.polish import PolishRequest, PolishResponse
from app.services.ai_service import AIService
from app.services.prompt_service import prompt_service
from app.services.prompt_service import prompt_service, PromptService
from app.logger import get_logger
from app.api.settings import get_user_ai_service
@@ -17,6 +17,7 @@ logger = get_logger(__name__)
@router.post("", response_model=PolishResponse, summary="AI去味")
async def polish_text(
request: PolishRequest,
http_request: Request,
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
@@ -32,15 +33,21 @@ async def polish_text(
这是本项目的核心特色功能!
"""
try:
# 构建AI去味提示词
prompt = prompt_service.get_denoising_prompt(
# 获取用户ID
user_id = getattr(http_request.state, 'user_id', None)
# 获取自定义提示词模板
template = await PromptService.get_template("AI_DENOISING", user_id, db)
# 格式化提示词
prompt = PromptService.format_prompt(
template,
original_text=request.original_text
)
logger.info(f"开始AI去味处理,原文长度: {len(request.original_text)}")
# 调用AI进行去味处理
polished_text = await ai_service.generate_text(
polished_text = await user_ai_service.generate_text(
prompt=prompt,
provider=request.provider,
model=request.model,
@@ -85,6 +92,7 @@ async def polish_batch(
project_id: int = None,
provider: str = None,
model: str = None,
http_request: Request = None,
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
@@ -94,12 +102,18 @@ async def polish_batch(
适用于一次性处理多个章节或段落
"""
try:
# 获取用户ID
user_id = getattr(http_request.state, 'user_id', None) if http_request else None
results = []
for idx, text in enumerate(texts):
logger.info(f"处理第 {idx+1}/{len(texts)} 个文本")
prompt = prompt_service.get_denoising_prompt(original_text=text)
# 获取自定义提示词模板
template = await PromptService.get_template("AI_DENOISING", user_id, db)
# 格式化提示词
prompt = PromptService.format_prompt(template, original_text=text)
polished_text = await user_ai_service.generate_text(
prompt=prompt,
+188 -43
View File
@@ -1,5 +1,5 @@
"""项目管理API"""
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Request
from fastapi.responses import Response
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, delete
@@ -13,6 +13,7 @@ from app.models.outline import Outline
from app.models.chapter import Chapter
from app.models.generation_history import GenerationHistory
from app.models.relationship import CharacterRelationship, Organization, OrganizationMember
from app.models.memory import StoryMemory, PlotAnalysis
from app.schemas.project import (
ProjectCreate,
ProjectUpdate,
@@ -25,6 +26,7 @@ from app.schemas.import_export import (
ImportResult
)
from app.services.import_export_service import ImportExportService
from app.services.memory_service import memory_service
from app.logger import get_logger
from app.utils.data_consistency import (
run_full_data_consistency_check,
@@ -39,17 +41,31 @@ router = APIRouter(prefix="/projects", tags=["项目管理"])
@router.post("", response_model=ProjectResponse, summary="创建项目")
async def create_project(
project: ProjectCreate,
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
request: Request = None
):
try:
logger.info(f"创建新项目: {project.title}")
db_project = Project(**project.model_dump())
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
logger.warning("未登录用户尝试创建项目")
raise HTTPException(status_code=401, detail="未登录")
logger.info(f"创建新项目: {project.title}, user_id={user_id}")
# 创建项目时自动设置user_id
project_data = project.model_dump()
project_data['user_id'] = user_id
db_project = Project(**project_data)
db.add(db_project)
await db.commit()
await db.refresh(db_project)
logger.info(f"项目创建成功: {db_project.id}")
logger.info(f"项目创建成功: project_id={db_project.id}, user_id={user_id}")
return db_project
except HTTPException:
raise
except Exception as e:
logger.error(f"创建项目失败: {str(e)}", exc_info=True)
raise
@@ -59,24 +75,38 @@ async def create_project(
async def get_projects(
skip: int = 0,
limit: int = 100,
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
request: Request = None
):
"""获取所有项目列表"""
"""获取当前用户的项目列表"""
try:
logger.debug(f"获取项目列表: skip={skip}, limit={limit}")
count_result = await db.execute(select(func.count(Project.id)))
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
logger.warning("未登录用户尝试获取项目列表")
raise HTTPException(status_code=401, detail="未登录")
logger.debug(f"获取项目列表: user_id={user_id}, skip={skip}, limit={limit}")
# 只查询当前用户的项目
count_result = await db.execute(
select(func.count(Project.id)).where(Project.user_id == user_id)
)
total = count_result.scalar_one()
result = await db.execute(
select(Project)
.where(Project.user_id == user_id)
.order_by(Project.updated_at.desc())
.offset(skip)
.limit(limit)
)
projects = result.scalars().all()
logger.info(f"获取项目列表成功: 共{total}个项目")
logger.info(f"获取项目列表成功: user_id={user_id},{total}个项目")
return ProjectListResponse(total=total, items=projects)
except HTTPException:
raise
except Exception as e:
logger.error(f"获取项目列表失败: {str(e)}", exc_info=True)
raise
@@ -85,17 +115,29 @@ async def get_projects(
@router.get("/{project_id}", response_model=ProjectResponse, summary="获取项目详情")
async def get_project(
project_id: str,
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
request: Request = None
):
try:
logger.debug(f"获取项目详情: {project_id}")
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
logger.warning("未登录用户尝试获取项目详情")
raise HTTPException(status_code=401, detail="未登录")
logger.debug(f"获取项目详情: project_id={project_id}, user_id={user_id}")
# 只查询当前用户的项目
result = await db.execute(
select(Project).where(Project.id == project_id)
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目不存在: {project_id}")
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在")
logger.info(f"获取项目详情成功: {project.title}")
@@ -111,17 +153,29 @@ async def get_project(
async def update_project(
project_id: str,
project_update: ProjectUpdate,
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
request: Request = None
):
try:
logger.info(f"更新项目: {project_id}")
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
logger.warning("未登录用户尝试更新项目")
raise HTTPException(status_code=401, detail="未登录")
logger.info(f"更新项目: project_id={project_id}, user_id={user_id}")
# 只查询当前用户的项目
result = await db.execute(
select(Project).where(Project.id == project_id)
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目不存在: {project_id}")
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在")
update_data = project_update.model_dump(exclude_unset=True)
@@ -143,21 +197,43 @@ async def update_project(
@router.delete("/{project_id}", summary="删除项目")
async def delete_project(
project_id: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
try:
logger.info(f"删除项目: {project_id}")
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
logger.warning("未登录用户尝试删除项目")
raise HTTPException(status_code=401, detail="未登录")
logger.info(f"删除项目: project_id={project_id}, user_id={user_id}")
# 只查询当前用户的项目
result = await db.execute(
select(Project).where(Project.id == project_id)
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目不存在: {project_id}")
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在")
project_title = project.title
# 删除向量数据库中的记忆(user_id已在上面获取)
if user_id:
try:
await memory_service.delete_project_memories(user_id, project_id)
logger.info(f"✅ 向量数据库清理成功")
except Exception as e:
logger.warning(f"⚠️ 向量数据库清理失败(继续删除其他数据): {str(e)}")
else:
logger.warning(f"⚠️ 未找到用户ID,跳过向量数据库清理")
relationships_result = await db.execute(
delete(CharacterRelationship).where(CharacterRelationship.project_id == project_id)
)
@@ -200,11 +276,14 @@ async def delete_project(
)
logger.debug(f"删除角色数: {characters_result.rowcount}")
# 注意:StoryMemory和PlotAnalysis会通过数据库级联删除自动清理
# 但向量数据库已在上面手动清理
await db.delete(project)
await db.commit()
logger.info(f"项目删除成功: {project_title}")
return {"message": "项目及所有关联数据删除成功"}
return {"message": "项目及所有关联数据(包括向量数据库)删除成功"}
except HTTPException:
raise
except Exception as e:
@@ -215,22 +294,33 @@ async def delete_project(
@router.get("/{project_id}/export", summary="导出项目章节为TXT")
async def export_project_chapters(
project_id: str,
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
request: Request = None
):
"""
导出项目的所有章节内容为TXT文本文件
按章节顺序组织包含项目基本信息
"""
try:
logger.info(f"开始导出项目: {project_id}")
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
logger.warning("未登录用户尝试导出项目")
raise HTTPException(status_code=401, detail="未登录")
logger.info(f"开始导出项目: project_id={project_id}, user_id={user_id}")
# 只查询当前用户的项目
result = await db.execute(
select(Project).where(Project.id == project_id)
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目不存在: {project_id}")
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在")
chapters_result = await db.execute(
@@ -264,6 +354,7 @@ async def export_project_chapters(
txt_content.append("\n" + "=" * 80 + "\n\n")
for chapter in chapters:
# 只显示主章节号,不显示子索引
txt_content.append(f"{chapter.chapter_number}{chapter.title}")
txt_content.append("-" * 80)
txt_content.append("") # 空行
@@ -276,7 +367,11 @@ async def export_project_chapters(
txt_content.append("\n\n" + "=" * 80 + "\n\n")
txt_content.append(f"--- 全文完 ---")
txt_content.append(f"\n导出时间: {func.now()}")
# 获取当前时间
from datetime import datetime
export_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
txt_content.append(f"\n导出时间: {export_time}")
final_content = "\n".join(txt_content)
@@ -307,6 +402,7 @@ async def export_project_chapters(
@router.post("/{project_id}/check-consistency", summary="检查数据一致性")
async def check_project_consistency(
project_id: str,
request: Request,
auto_fix: bool = True,
db: AsyncSession = Depends(get_db)
):
@@ -324,15 +420,25 @@ async def check_project_consistency(
- organization_members: 验证组织成员数据完整性
"""
try:
logger.info(f"开始数据一致性检查: {project_id}, auto_fix={auto_fix}")
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
logger.warning("未登录用户尝试检查数据一致性")
raise HTTPException(status_code=401, detail="未登录")
logger.info(f"开始数据一致性检查: project_id={project_id}, user_id={user_id}, auto_fix={auto_fix}")
# 只查询当前用户的项目
result = await db.execute(
select(Project).where(Project.id == project_id)
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目不存在: {project_id}")
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在")
report = await run_full_data_consistency_check(project_id, db, auto_fix)
@@ -350,6 +456,7 @@ async def check_project_consistency(
@router.post("/{project_id}/fix-organizations", summary="修复组织记录")
async def fix_project_organizations(
project_id: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
@@ -358,15 +465,25 @@ async def fix_project_organizations(
为所有is_organization=True但没有Organization记录的Character创建记录
"""
try:
logger.info(f"开始修复组织记录: {project_id}")
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
logger.warning("未登录用户尝试修复组织记录")
raise HTTPException(status_code=401, detail="未登录")
logger.info(f"开始修复组织记录: project_id={project_id}, user_id={user_id}")
# 只查询当前用户的项目
result = await db.execute(
select(Project).where(Project.id == project_id)
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目不存在: {project_id}")
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在")
fixed_count, total_count = await fix_missing_organization_records(project_id, db)
@@ -388,6 +505,7 @@ async def fix_project_organizations(
@router.post("/{project_id}/fix-member-counts", summary="修复成员计数")
async def fix_project_member_counts(
project_id: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
@@ -396,15 +514,25 @@ async def fix_project_member_counts(
从实际成员记录重新计算每个组织的member_count
"""
try:
logger.info(f"开始修复成员计数: {project_id}")
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
logger.warning("未登录用户尝试修复成员计数")
raise HTTPException(status_code=401, detail="未登录")
logger.info(f"开始修复成员计数: project_id={project_id}, user_id={user_id}")
# 只查询当前用户的项目
result = await db.execute(
select(Project).where(Project.id == project_id)
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目不存在: {project_id}")
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在")
fixed_count, total_count = await fix_organization_member_counts(project_id, db)
@@ -426,6 +554,7 @@ async def fix_project_member_counts(
@router.post("/{project_id}/export-data", summary="导出项目数据为JSON")
async def export_project_data(
project_id: str,
request: Request,
options: ExportOptions,
db: AsyncSession = Depends(get_db)
):
@@ -440,16 +569,25 @@ async def export_project_data(
JSON文件下载
"""
try:
logger.info(f"开始导出项目数据: {project_id}")
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
logger.warning("未登录用户尝试导出项目数据")
raise HTTPException(status_code=401, detail="未登录")
# 检查项目是否存在
logger.info(f"开始导出项目数据: project_id={project_id}, user_id={user_id}")
# 只查询当前用户的项目
result = await db.execute(
select(Project).where(Project.id == project_id)
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目不存在: {project_id}")
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在")
# 导出数据
@@ -538,6 +676,7 @@ async def validate_import_file(
@router.post("/import", response_model=ImportResult, summary="导入项目")
async def import_project(
file: UploadFile = File(...),
request: Request = None,
db: AsyncSession = Depends(get_db)
):
"""
@@ -550,7 +689,13 @@ async def import_project(
导入结果
"""
try:
logger.info(f"开始导入项目: {file.filename}")
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
logger.warning("未登录用户尝试导入项目")
raise HTTPException(status_code=401, detail="未登录")
logger.info(f"开始导入项目: {file.filename}, user_id={user_id}")
# 检查文件类型
if not file.filename.endswith('.json'):
@@ -570,8 +715,8 @@ async def import_project(
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"无效的JSON格式: {str(e)}")
# 导入数据
import_result = await ImportExportService.import_project(data, db)
# 导入数据(传入user_id
import_result = await ImportExportService.import_project(data, db, user_id)
if import_result.success:
logger.info(f"项目导入成功: {import_result.project_id}")
+630
View File
@@ -0,0 +1,630 @@
"""提示词模板管理 API"""
from fastapi import APIRouter, HTTPException, Depends, Query, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, delete
from typing import List, Optional
from datetime import datetime
import json
import hashlib
from app.database import get_db
from app.models.prompt_template import PromptTemplate
from app.schemas.prompt_template import (
PromptTemplateCreate,
PromptTemplateUpdate,
PromptTemplateResponse,
PromptTemplateListResponse,
PromptTemplateCategoryResponse,
PromptTemplateExport,
PromptTemplateExportItem,
PromptTemplateImportResult,
PromptTemplatePreviewRequest
)
from app.services.prompt_service import PromptService
from app.logger import get_logger
logger = get_logger(__name__)
def calculate_content_hash(content: str) -> str:
"""计算模板内容的SHA256哈希值"""
return hashlib.sha256(content.strip().encode('utf-8')).hexdigest()[:16]
router = APIRouter(prefix="/prompt-templates", tags=["提示词模板管理"])
@router.get("", response_model=PromptTemplateListResponse)
async def get_all_templates(
request: Request,
category: Optional[str] = Query(None, description="按分类筛选"),
is_active: Optional[bool] = Query(None, description="按启用状态筛选"),
db: AsyncSession = Depends(get_db)
):
"""
获取用户所有提示词模板
"""
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
query = select(PromptTemplate).where(PromptTemplate.user_id == user_id)
if category:
query = query.where(PromptTemplate.category == category)
if is_active is not None:
query = query.where(PromptTemplate.is_active == is_active)
query = query.order_by(PromptTemplate.category, PromptTemplate.template_key)
result = await db.execute(query)
templates = result.scalars().all()
# 获取所有分类
categories_result = await db.execute(
select(PromptTemplate.category)
.where(PromptTemplate.user_id == user_id)
.distinct()
)
categories = [c for c in categories_result.scalars().all() if c]
return PromptTemplateListResponse(
templates=templates,
total=len(templates),
categories=sorted(categories)
)
@router.get("/categories", response_model=List[PromptTemplateCategoryResponse])
async def get_templates_by_category(
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
按分类获取提示词模板合并用户自定义和系统默认
"""
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
# 1. 查询用户自定义模板
result = await db.execute(
select(PromptTemplate)
.where(PromptTemplate.user_id == user_id)
.order_by(PromptTemplate.category, PromptTemplate.template_key)
)
user_templates = result.scalars().all()
# 2. 获取所有系统默认模板
system_templates = PromptService.get_all_system_templates()
# 3. 构建用户自定义模板的键集合
user_template_keys = {t.template_key for t in user_templates}
# 4. 合并模板:用户自定义的 + 未自定义的系统默认
all_templates = []
current_time = datetime.now()
# 添加用户自定义的模板
for user_template in user_templates:
user_template.is_system_default = False # 标记为已自定义
all_templates.append(user_template)
# 添加未自定义的系统默认模板
for sys_template in system_templates:
if sys_template['template_key'] not in user_template_keys:
# 这个系统模板用户还没有自定义,创建临时对象
template_obj = PromptTemplate(
id=sys_template['template_key'], # 使用template_key作为临时ID
user_id=user_id,
template_key=sys_template['template_key'],
template_name=sys_template['template_name'],
template_content=sys_template['content'],
description=sys_template['description'],
category=sys_template['category'],
parameters=json.dumps(sys_template['parameters']),
is_active=True,
is_system_default=True,
created_at=current_time,
updated_at=current_time
)
all_templates.append(template_obj)
# 5. 按分类分组
category_dict = {}
for template in all_templates:
cat = template.category or "未分类"
if cat not in category_dict:
category_dict[cat] = []
category_dict[cat].append(template)
# 6. 构建响应
response = []
for category, temps in sorted(category_dict.items()):
# 按template_key排序,确保顺序一致
temps.sort(key=lambda t: t.template_key)
response.append(PromptTemplateCategoryResponse(
category=category,
count=len(temps),
templates=temps
))
return response
@router.get("/system-defaults")
async def get_system_defaults(
request: Request
):
"""
获取所有系统默认提示词模板
"""
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
# 从PromptService获取所有系统默认模板
system_templates = PromptService.get_all_system_templates()
return {
"templates": system_templates,
"total": len(system_templates)
}
@router.get("/{template_key}", response_model=PromptTemplateResponse)
async def get_template(
template_key: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
获取指定的提示词模板
"""
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
result = await db.execute(
select(PromptTemplate).where(
PromptTemplate.user_id == user_id,
PromptTemplate.template_key == template_key
)
)
template = result.scalar_one_or_none()
if not template:
raise HTTPException(status_code=404, detail=f"模板 {template_key} 不存在")
return template
@router.post("", response_model=PromptTemplateResponse)
async def create_or_update_template(
data: PromptTemplateCreate,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
创建或更新提示词模板Upsert
"""
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
# 查找现有模板
result = await db.execute(
select(PromptTemplate).where(
PromptTemplate.user_id == user_id,
PromptTemplate.template_key == data.template_key
)
)
template = result.scalar_one_or_none()
if template:
# 更新现有模板
for key, value in data.model_dump(exclude_unset=True).items():
setattr(template, key, value)
logger.info(f"用户 {user_id} 更新模板 {data.template_key}")
else:
# 创建新模板
template = PromptTemplate(
user_id=user_id,
**data.model_dump()
)
db.add(template)
logger.info(f"用户 {user_id} 创建模板 {data.template_key}")
await db.commit()
await db.refresh(template)
return template
@router.put("/{template_key}", response_model=PromptTemplateResponse)
async def update_template(
template_key: str,
data: PromptTemplateUpdate,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
更新提示词模板
"""
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
result = await db.execute(
select(PromptTemplate).where(
PromptTemplate.user_id == user_id,
PromptTemplate.template_key == template_key
)
)
template = result.scalar_one_or_none()
if not template:
raise HTTPException(status_code=404, detail=f"模板 {template_key} 不存在")
# 更新模板
update_data = data.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(template, key, value)
await db.commit()
await db.refresh(template)
logger.info(f"用户 {user_id} 更新模板 {template_key}")
return template
@router.delete("/{template_key}")
async def delete_template(
template_key: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
删除自定义提示词模板
"""
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
result = await db.execute(
select(PromptTemplate).where(
PromptTemplate.user_id == user_id,
PromptTemplate.template_key == template_key
)
)
template = result.scalar_one_or_none()
if not template:
raise HTTPException(status_code=404, detail=f"模板 {template_key} 不存在")
await db.delete(template)
await db.commit()
logger.info(f"用户 {user_id} 删除模板 {template_key}")
return {"message": "模板已删除", "template_key": template_key}
@router.post("/{template_key}/reset")
async def reset_to_default(
template_key: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
重置为系统默认模板删除用户自定义版本
"""
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
# 验证系统默认模板是否存在
system_template = PromptService.get_system_template_info(template_key)
if not system_template:
raise HTTPException(status_code=404, detail=f"系统默认模板 {template_key} 不存在")
# 查找并删除用户的自定义模板
result = await db.execute(
select(PromptTemplate).where(
PromptTemplate.user_id == user_id,
PromptTemplate.template_key == template_key
)
)
template = result.scalar_one_or_none()
if template:
await db.delete(template)
await db.commit()
logger.info(f"用户 {user_id} 删除自定义模板 {template_key},恢复为系统默认")
return {"message": "已重置为系统默认", "template_key": template_key}
else:
# 用户本来就没有自定义,已经是系统默认状态
logger.info(f"用户 {user_id} 的模板 {template_key} 本来就是系统默认")
return {"message": "已是系统默认状态", "template_key": template_key}
@router.post("/export", response_model=PromptTemplateExport)
async def export_templates(
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
导出所有提示词模板包括用户自定义和系统默认
- 用户自定义的提示词标记为 is_customized=true
- 系统默认的提示词标记为 is_customized=false
"""
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
# 1. 查询用户自定义模板
result = await db.execute(
select(PromptTemplate).where(PromptTemplate.user_id == user_id)
)
user_templates = result.scalars().all()
# 2. 获取所有系统默认模板
system_templates = PromptService.get_all_system_templates()
# 3. 构建用户自定义模板的键集合
user_template_keys = {t.template_key for t in user_templates}
# 4. 准备导出数据
export_items = []
customized_count = 0
system_default_count = 0
# 添加用户自定义的模板
for user_template in user_templates:
# 获取对应的系统模板用于计算哈希
system_template = next(
(t for t in system_templates if t["template_key"] == user_template.template_key),
None
)
system_hash = calculate_content_hash(system_template["content"]) if system_template else None
export_items.append(PromptTemplateExportItem(
template_key=user_template.template_key,
template_name=user_template.template_name,
template_content=user_template.template_content,
description=user_template.description,
category=user_template.category,
parameters=user_template.parameters,
is_active=user_template.is_active,
is_customized=True,
system_content_hash=system_hash
))
customized_count += 1
# 添加未自定义的系统默认模板
for sys_template in system_templates:
if sys_template['template_key'] not in user_template_keys:
export_items.append(PromptTemplateExportItem(
template_key=sys_template['template_key'],
template_name=sys_template['template_name'],
template_content=sys_template['content'],
description=sys_template['description'],
category=sys_template['category'],
parameters=json.dumps(sys_template['parameters']),
is_active=True,
is_customized=False,
system_content_hash=calculate_content_hash(sys_template['content'])
))
system_default_count += 1
statistics = {
"total": len(export_items),
"customized": customized_count,
"system_default": system_default_count
}
logger.info(f"用户 {user_id} 导出了 {statistics['total']} 个模板 "
f"(自定义: {statistics['customized']}, 系统默认: {statistics['system_default']})")
return PromptTemplateExport(
templates=export_items,
export_time=datetime.now(),
version="2.0",
statistics=statistics
)
@router.post("/import", response_model=PromptTemplateImportResult)
async def import_templates(
data: PromptTemplateExport,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
智能导入提示词模板
- 如果导入的是系统默认且内容未修改 删除自定义记录使用系统默认
- 如果导入的是系统默认但内容已修改 创建自定义记录
- 如果导入的是用户自定义 创建/更新自定义记录
"""
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
# 获取所有系统默认模板用于比对
system_templates = PromptService.get_all_system_templates()
system_template_dict = {t["template_key"]: t for t in system_templates}
# 统计信息
kept_system_default = 0 # 保持系统默认
created_or_updated = 0 # 创建或更新自定义
converted_to_custom = 0 # 从系统默认转为自定义
converted_templates = [] # 被转换的模板列表
for template_data in data.templates:
template_key = template_data.template_key
is_customized = template_data.is_customized
imported_content = template_data.template_content.strip()
# 查找当前用户是否已有该模板的自定义版本
result = await db.execute(
select(PromptTemplate).where(
PromptTemplate.user_id == user_id,
PromptTemplate.template_key == template_key
)
)
existing = result.scalar_one_or_none()
# 获取系统默认模板
system_template = system_template_dict.get(template_key)
if not is_customized:
# 导入的标记为系统默认
if system_template:
system_content = system_template["content"].strip()
# 比对内容是否与系统默认一致
if imported_content == system_content:
# 内容一致,删除自定义记录(如果有)
if existing:
await db.delete(existing)
logger.info(f"用户 {user_id} 的模板 {template_key} 恢复为系统默认(删除自定义)")
kept_system_default += 1
else:
# 内容不一致,用户修改过,创建/更新为自定义
if existing:
# 更新现有自定义
existing.template_name = template_data.template_name
existing.template_content = template_data.template_content
existing.description = template_data.description
existing.category = template_data.category
existing.parameters = template_data.parameters
existing.is_active = template_data.is_active
else:
# 创建新自定义
new_template = PromptTemplate(
user_id=user_id,
template_key=template_data.template_key,
template_name=template_data.template_name,
template_content=template_data.template_content,
description=template_data.description,
category=template_data.category,
parameters=template_data.parameters,
is_active=template_data.is_active
)
db.add(new_template)
converted_to_custom += 1
converted_templates.append({
"template_key": template_key,
"template_name": template_data.template_name,
"reason": "内容与系统默认不一致,已转为自定义"
})
logger.info(f"用户 {user_id} 的模板 {template_key} 内容已修改,转为自定义")
else:
# 系统中不存在该模板,作为自定义导入
if existing:
existing.template_name = template_data.template_name
existing.template_content = template_data.template_content
existing.description = template_data.description
existing.category = template_data.category
existing.parameters = template_data.parameters
existing.is_active = template_data.is_active
else:
new_template = PromptTemplate(
user_id=user_id,
template_key=template_data.template_key,
template_name=template_data.template_name,
template_content=template_data.template_content,
description=template_data.description,
category=template_data.category,
parameters=template_data.parameters,
is_active=template_data.is_active
)
db.add(new_template)
created_or_updated += 1
else:
# 导入的标记为用户自定义,直接创建/更新
if existing:
existing.template_name = template_data.template_name
existing.template_content = template_data.template_content
existing.description = template_data.description
existing.category = template_data.category
existing.parameters = template_data.parameters
existing.is_active = template_data.is_active
else:
new_template = PromptTemplate(
user_id=user_id,
template_key=template_data.template_key,
template_name=template_data.template_name,
template_content=template_data.template_content,
description=template_data.description,
category=template_data.category,
parameters=template_data.parameters,
is_active=template_data.is_active
)
db.add(new_template)
created_or_updated += 1
await db.commit()
statistics = {
"total": len(data.templates),
"kept_system_default": kept_system_default,
"created_or_updated": created_or_updated,
"converted_to_custom": converted_to_custom
}
logger.info(f"用户 {user_id} 导入完成: {statistics}")
return PromptTemplateImportResult(
message="导入成功",
statistics=statistics,
converted_templates=converted_templates
)
@router.post("/{template_key}/preview")
async def preview_template(
template_key: str,
data: PromptTemplatePreviewRequest,
request: Request
):
"""
预览提示词模板渲染变量
"""
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
try:
# 使用PromptService的format_prompt方法
rendered = PromptService.format_prompt(
data.template_content,
**data.parameters
)
return {
"success": True,
"rendered_content": rendered,
"parameters_used": list(data.parameters.keys())
}
except KeyError as e:
return {
"success": False,
"error": f"缺少必需的参数: {str(e)}",
"rendered_content": None
}
except Exception as e:
return {
"success": False,
"error": f"渲染失败: {str(e)}",
"rendered_content": None
}
+47 -1
View File
@@ -1,5 +1,5 @@
"""关系管理API"""
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, or_, and_
from typing import List, Optional
@@ -12,6 +12,7 @@ from app.models.relationship import (
OrganizationMember
)
from app.models.character import Character
from app.models.project import Project
from app.schemas.relationship import (
RelationshipTypeResponse,
CharacterRelationshipCreate,
@@ -27,6 +28,26 @@ router = APIRouter(prefix="/relationships", tags=["关系管理"])
logger = get_logger(__name__)
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
"""验证用户是否有权访问指定项目"""
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
result = await db.execute(
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
return project
@router.get("/types", response_model=List[RelationshipTypeResponse], summary="获取关系类型列表")
async def get_relationship_types(db: AsyncSession = Depends(get_db)):
"""获取所有预定义的关系类型"""
@@ -38,9 +59,14 @@ async def get_relationship_types(db: AsyncSession = Depends(get_db)):
@router.get("/project/{project_id}", response_model=List[CharacterRelationshipResponse], summary="获取项目的所有关系")
async def get_project_relationships(
project_id: str,
request: Request,
character_id: Optional[str] = Query(None, description="筛选特定角色的关系"),
db: AsyncSession = Depends(get_db)
):
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(project_id, user_id, db)
"""
获取项目中的所有角色关系
@@ -70,8 +96,13 @@ async def get_project_relationships(
@router.get("/graph/{project_id}", response_model=RelationshipGraphData, summary="获取关系图谱数据")
async def get_relationship_graph(
project_id: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(project_id, user_id, db)
"""
获取用于可视化的关系图谱数据
@@ -122,6 +153,7 @@ async def get_relationship_graph(
@router.post("/", response_model=CharacterRelationshipResponse, summary="创建角色关系")
async def create_relationship(
relationship: CharacterRelationshipCreate,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
@@ -131,6 +163,10 @@ async def create_relationship(
- 可以指定预定义的关系类型或自定义关系名称
- 可以设置亲密度状态等属性
"""
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(relationship.project_id, user_id, db)
# 验证角色是否存在
char_from = await db.execute(
select(Character).where(Character.id == relationship.character_from_id)
@@ -161,6 +197,7 @@ async def create_relationship(
async def update_relationship(
relationship_id: str,
relationship: CharacterRelationshipUpdate,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""更新角色关系的属性(亲密度、状态等)"""
@@ -174,6 +211,10 @@ async def update_relationship(
if not db_rel:
raise HTTPException(status_code=404, detail="关系不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(db_rel.project_id, user_id, db)
# 更新字段
update_data = relationship.model_dump(exclude_unset=True)
for field, value in update_data.items():
@@ -189,6 +230,7 @@ async def update_relationship(
@router.delete("/{relationship_id}", summary="删除关系")
async def delete_relationship(
relationship_id: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""删除角色关系"""
@@ -202,6 +244,10 @@ async def delete_relationship(
if not db_rel:
raise HTTPException(status_code=404, detail="关系不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(db_rel.project_id, user_id, db)
await db.delete(db_rel)
await db.commit()
+618 -25
View File
@@ -4,18 +4,25 @@
from fastapi import APIRouter, HTTPException, Request, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from typing import Dict, Any, List
from typing import Dict, Any, List, Optional
from pathlib import Path
from pydantic import BaseModel
from datetime import datetime
import httpx
import json
import time
from app.database import get_db
from app.models.settings import Settings
from app.schemas.settings import SettingsCreate, SettingsUpdate, SettingsResponse
from app.schemas.settings import (
SettingsCreate, SettingsUpdate, SettingsResponse,
APIKeyPreset, APIKeyPresetConfig, PresetCreateRequest,
PresetUpdateRequest, PresetResponse, PresetListResponse
)
from app.user_manager import User
from app.logger import get_logger
from app.config import settings as app_settings, PROJECT_ROOT
from app.services.ai_service import AIService, create_user_ai_service
from app.services.ai_service import AIService, create_user_ai_service, create_user_ai_service_with_mcp
logger = get_logger(__name__)
@@ -46,9 +53,14 @@ async def get_user_ai_service(
db: AsyncSession = Depends(get_db)
) -> AIService:
"""
依赖获取当前用户的AI服务实例
从数据库读取用户设置并创建对应的AI服务
依赖获取当前用户的AI服务实例支持MCP工具自动加载
从数据库读取用户设置并创建对应的AI服务
自动传递 user_id db_session使得 AIService 能够加载用户配置的MCP工具
根据用户的所有MCP插件状态决定是否启用MCP如果有启用的插件则启用否则禁用
"""
from app.models.mcp_plugin import MCPPlugin
result = await db.execute(
select(Settings).where(Settings.user_id == user.user_id)
)
@@ -66,14 +78,34 @@ async def get_user_ai_service(
await db.refresh(settings)
logger.info(f"用户 {user.user_id} 首次使用AI服务,已从.env同步设置到数据库")
# 使用用户设置创建AI服务实例
return create_user_ai_service(
# 查询用户的所有MCP插件状态
mcp_result = await db.execute(
select(MCPPlugin).where(MCPPlugin.user_id == user.user_id)
)
mcp_plugins = mcp_result.scalars().all()
# 检查是否有启用的MCP插件
enable_mcp = any(plugin.enabled for plugin in mcp_plugins) if mcp_plugins else False
if mcp_plugins:
enabled_count = sum(1 for p in mcp_plugins if p.enabled)
logger.info(f"用户 {user.user_id}{len(mcp_plugins)} 个MCP插件,{enabled_count} 个启用,{enable_mcp} 决定使用MCP")
else:
logger.debug(f"用户 {user.user_id} 没有配置MCP插件,禁用MCP")
# ✅ 使用支持MCP的工厂函数创建AI服务实例
# 传递 user_id 和 db_session,使得 AIService 能够自动加载用户配置的MCP工具
return create_user_ai_service_with_mcp(
api_provider=settings.api_provider,
api_key=settings.api_key,
api_base_url=settings.api_base_url or "",
model_name=settings.llm_model,
temperature=settings.temperature,
max_tokens=settings.max_tokens
max_tokens=settings.max_tokens,
user_id=user.user_id, # ✅ 传递 user_id
db_session=db, # ✅ 传递 db_session
system_prompt=settings.system_prompt,
enable_mcp=enable_mcp, # 根据MCP插件状态动态决定
)
@@ -242,10 +274,8 @@ async def get_available_models(
if "data" in data and isinstance(data["data"], list):
for model in data["data"]:
model_id = model.get("id", "")
# 过滤出常用的文本生成模型
if any(keyword in model_id.lower() for keyword in [
"gpt", "gemini", "claude", "llama", "mistral", "qwen", "deepseek"
]):
# 返回所有模型,不进行过滤
if model_id:
models.append({
"value": model_id,
"label": model_id,
@@ -266,17 +296,30 @@ async def get_available_models(
}
elif provider == "anthropic":
# Anthropic 没有公开的模型列表API
raise HTTPException(
status_code=400,
detail="Anthropic 不支持自动获取模型列表,请手动输入模型名称"
)
# Anthropic models API
url = f"{api_base_url.rstrip('/')}/v1/models"
headers = {"x-api-key": api_key, "anthropic-version": "2023-06-01"}
response = await client.get(url, headers=headers)
response.raise_for_status()
data = response.json()
models = [{"value": m["id"], "label": m["id"], "description": m.get("display_name", "")} for m in data.get("data", [])]
return {"provider": provider, "models": models, "count": len(models)}
elif provider == "gemini":
# Gemini models API
url = f"{api_base_url.rstrip('/')}/models?key={api_key}"
response = await client.get(url)
response.raise_for_status()
data = response.json()
models = []
for m in data.get("models", []):
if "generateContent" in m.get("supportedGenerationMethods", []):
mid = m.get("name", "").replace("models/", "")
models.append({"value": mid, "label": m.get("displayName", mid), "description": ""})
return {"provider": provider, "models": models, "count": len(models)}
else:
raise HTTPException(
status_code=400,
detail=f"不支持的提供商: {provider}"
)
raise HTTPException(status_code=400, detail=f"不支持的提供商: {provider}")
except httpx.HTTPStatusError as e:
logger.error(f"获取模型列表失败 (HTTP {e.response.status_code}): {e.response.text}")
@@ -308,6 +351,227 @@ class ApiTestRequest(BaseModel):
llm_model: str
@router.post("/check-function-calling")
async def check_function_calling_support(data: ApiTestRequest):
"""
检查模型是否支持 Function Calling工具调用
基于业界最佳实践的测试方法
1. 发送包含工具定义的请求
2. 检查响应的 finish_reason 是否为 "tool_calls"
3. 验证响应中是否包含有效的 tool_calls 数据
Args:
data: 包含 API 配置的请求数据
Returns:
检测结果包含支持状态详细信息和建议
"""
api_key = data.api_key
api_base_url = data.api_base_url
provider = data.provider
llm_model = data.llm_model
try:
start_time = time.time()
# 定义一个简单的测试工具(天气查询)
test_tools = [{
"type": "function",
"function": {
"name": "get_weather",
"description": "获取指定城市的当前天气信息",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "城市名称,例如:北京、上海、深圳"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "温度单位"
}
},
"required": ["city"]
}
}
}]
# 测试提示:故意设计一个需要调用工具的问题
test_prompt = "请告诉我北京现在的天气情况如何?"
logger.info(f"🧪 开始检测 Function Calling 支持")
logger.info(f" - 提供商: {provider}")
logger.info(f" - 模型: {llm_model}")
logger.info(f" - 测试工具: get_weather")
# 创建临时 AI 服务实例进行测试
test_service = AIService(
api_provider=provider,
api_key=api_key,
api_base_url=api_base_url,
default_model=llm_model,
default_temperature=0.3, # 使用较低温度以获得更确定的行为
default_max_tokens=200
)
# 发送带工具的测试请求
response = await test_service.generate_text(
prompt=test_prompt,
provider=provider,
model=llm_model,
temperature=0.3,
max_tokens=200,
tools=test_tools,
tool_choice="auto", # 让模型自动决定是否使用工具
auto_mcp=False # 禁用 MCP 自动加载
)
end_time = time.time()
response_time = round((end_time - start_time) * 1000, 2)
# 分析响应以确定是否支持 Function Calling
supported = False
finish_reason = None
tool_calls = None
response_content = None
if isinstance(response, dict):
# 检查 finish_reasonOpenAI 标准)
finish_reason = response.get("finish_reason")
# 检查是否有 tool_calls
if "tool_calls" in response and response["tool_calls"]:
supported = True
tool_calls = response["tool_calls"]
logger.info(f"✅ 检测到工具调用: {len(tool_calls)}")
# 记录返回的内容(如果有)
if "content" in response:
response_content = response["content"]
elif isinstance(response, str):
# 如果只返回字符串,说明不支持工具调用
response_content = response
logger.info(f" - 响应时间: {response_time}ms")
logger.info(f" - finish_reason: {finish_reason}")
logger.info(f" - 支持状态: {'✅ 支持' if supported else '❌ 不支持'}")
# 构建详细的返回信息
result = {
"success": True,
"supported": supported,
"message": "✅ 模型支持 Function Calling" if supported else "❌ 模型不支持 Function Calling",
"response_time_ms": response_time,
"provider": provider,
"model": llm_model,
"details": {
"finish_reason": finish_reason,
"has_tool_calls": bool(tool_calls),
"tool_call_count": len(tool_calls) if tool_calls else 0,
"test_tool": "get_weather",
"test_prompt": test_prompt,
"response_type": "tool_calls" if supported else "text"
}
}
# 添加工具调用详情
if tool_calls:
result["tool_calls"] = tool_calls
result["suggestions"] = [
"✅ 该模型支持 Function Calling,可以正常使用 MCP 插件",
"建议:启用需要的 MCP 插件以扩展 AI 能力",
"提示:测试成功检测到工具调用,模型能够正确解析和使用外部工具"
]
else:
result["response_preview"] = response_content[:200] if response_content else None
result["suggestions"] = [
"❌ 该模型不支持 Function Calling,无法使用 MCP 插件功能",
"建议:更换支持工具调用的模型",
"推荐模型:GPT-4 系列、GPT-4-turbo、Claude 3 Opus/Sonnet、Gemini 1.5 Pro 等",
"说明:模型返回了文本回复而非工具调用,表明不支持该功能"
]
return result
except ValueError as e:
error_msg = str(e)
logger.error(f"❌ Function Calling 检测配置错误: {error_msg}")
return {
"success": False,
"supported": False,
"message": "配置错误",
"error": error_msg,
"error_type": "ConfigurationError",
"suggestions": [
"请检查 API Key 是否正确",
"请确认 API Base URL 格式是否正确",
"请验证所选提供商与配置是否匹配"
]
}
except TimeoutError as e:
error_msg = str(e)
logger.error(f"❌ Function Calling 检测超时: {error_msg}")
return {
"success": False,
"supported": None,
"message": "检测超时",
"error": error_msg,
"error_type": "TimeoutError",
"suggestions": [
"请检查网络连接是否正常",
"请确认 API 服务是否可访问",
"建议:稍后重试或使用其他网络环境"
]
}
except Exception as e:
error_msg = str(e)
error_type = type(e).__name__
logger.error(f"❌ Function Calling 检测失败: {error_msg}")
logger.error(f" - 错误类型: {error_type}")
# 智能分析错误原因
suggestions = []
if "tool" in error_msg.lower() or "function" in error_msg.lower():
suggestions = [
"该模型可能不支持 Function Calling 功能",
"API 返回了与工具调用相关的错误",
"建议:更换支持工具调用的模型或联系 API 提供商"
]
elif "unauthorized" in error_msg.lower() or "401" in error_msg:
suggestions = [
"API Key 认证失败",
"请检查 API Key 是否正确且有效",
"请确认 API Key 是否有足够的权限"
]
elif "not found" in error_msg.lower() or "404" in error_msg:
suggestions = [
"模型不存在或不可用",
"请检查模型名称是否正确",
"请确认该模型在当前 API 中是否可用"
]
else:
suggestions = [
"检测过程中遇到未知错误",
"建议:检查所有配置参数是否正确",
"提示:查看详细错误信息以获取更多线索"
]
return {
"success": False,
"supported": False,
"message": "Function Calling 检测失败",
"error": error_msg,
"error_type": error_type,
"suggestions": suggestions
}
@router.post("/test")
async def test_api_connection(data: ApiTestRequest):
"""
@@ -351,7 +615,8 @@ async def test_api_connection(data: ApiTestRequest):
provider=provider,
model=llm_model,
temperature=0.7,
max_tokens=8000
max_tokens=8000,
auto_mcp=False # 测试时不加载MCP工具
)
end_time = time.time()
@@ -359,7 +624,10 @@ async def test_api_connection(data: ApiTestRequest):
logger.info(f"✅ API 测试成功")
logger.info(f" - 响应时间: {response_time}ms")
logger.info(f" - 响应内容: {response[:100] if response else 'N/A'}")
# 安全地处理响应内容(确保是字符串)
response_str = str(response) if response else 'N/A'
logger.info(f" - 响应内容: {response_str[:100]}")
return {
"success": True,
@@ -367,7 +635,7 @@ async def test_api_connection(data: ApiTestRequest):
"response_time_ms": response_time,
"provider": provider,
"model": llm_model,
"response_preview": response[:100] if response and len(response) > 100 else response,
"response_preview": response_str[:100] if len(response_str) > 100 else response_str,
"details": {
"api_available": True,
"model_accessible": True,
@@ -461,4 +729,329 @@ async def test_api_connection(data: ApiTestRequest):
"error": error_msg,
"error_type": error_type,
"suggestions": suggestions
}
}
# ========== API配置预设管理(零数据库改动方案)==========
async def get_user_settings(user_id: str, db: AsyncSession) -> Settings:
"""获取用户settings,如果不存在则创建"""
result = await db.execute(
select(Settings).where(Settings.user_id == user_id)
)
settings = result.scalar_one_or_none()
if not settings:
# 创建默认设置
env_defaults = read_env_defaults()
settings = Settings(
user_id=user_id,
**env_defaults,
preferences='{}' # 初始化为空JSON
)
db.add(settings)
await db.commit()
await db.refresh(settings)
logger.info(f"用户 {user_id} 首次访问,已创建默认设置")
return settings
@router.get("/presets", response_model=PresetListResponse)
async def get_presets(
user: User = Depends(require_login),
db: AsyncSession = Depends(get_db)
):
"""
获取所有API配置预设
从preferences字段读取预设列表
"""
settings = await get_user_settings(user.user_id, db)
# 解析preferences
try:
prefs = json.loads(settings.preferences or '{}')
except json.JSONDecodeError:
logger.warning(f"用户 {user.user_id} 的preferences字段JSON格式错误,重置为空")
prefs = {}
api_presets = prefs.get('api_presets', {'presets': [], 'version': '1.0'})
presets = api_presets.get('presets', [])
# 找到激活的预设
active_preset_id = next(
(p['id'] for p in presets if p.get('is_active')),
None
)
logger.info(f"用户 {user.user_id} 获取预设列表,共 {len(presets)}")
return {
"presets": presets,
"total": len(presets),
"active_preset_id": active_preset_id
}
@router.post("/presets", response_model=PresetResponse)
async def create_preset(
data: PresetCreateRequest,
user: User = Depends(require_login),
db: AsyncSession = Depends(get_db)
):
"""
创建新预设
将预设添加到preferences字段的JSON中
"""
settings = await get_user_settings(user.user_id, db)
# 解析preferences
try:
prefs = json.loads(settings.preferences or '{}')
except json.JSONDecodeError:
prefs = {}
api_presets = prefs.get('api_presets', {'presets': [], 'version': '1.0'})
presets = api_presets.get('presets', [])
# 创建新预设
new_preset = {
"id": f"preset_{int(datetime.now().timestamp() * 1000)}",
"name": data.name,
"description": data.description,
"is_active": False,
"created_at": datetime.now().isoformat(),
"config": data.config.model_dump()
}
presets.append(new_preset)
# 保存回preferences
api_presets['presets'] = presets
prefs['api_presets'] = api_presets
settings.preferences = json.dumps(prefs, ensure_ascii=False)
await db.commit()
logger.info(f"用户 {user.user_id} 创建预设: {data.name}")
return new_preset
@router.put("/presets/{preset_id}", response_model=PresetResponse)
async def update_preset(
preset_id: str,
data: PresetUpdateRequest,
user: User = Depends(require_login),
db: AsyncSession = Depends(get_db)
):
"""
更新预设
在preferences字段的JSON中更新指定预设
"""
settings = await get_user_settings(user.user_id, db)
# 解析preferences
try:
prefs = json.loads(settings.preferences or '{}')
except json.JSONDecodeError:
raise HTTPException(status_code=500, detail="配置数据格式错误")
api_presets = prefs.get('api_presets', {'presets': [], 'version': '1.0'})
presets = api_presets.get('presets', [])
# 找到并更新预设
target_preset = next((p for p in presets if p['id'] == preset_id), None)
if not target_preset:
raise HTTPException(status_code=404, detail="预设不存在")
# 更新字段
if data.name is not None:
target_preset['name'] = data.name
if data.description is not None:
target_preset['description'] = data.description
if data.config is not None:
target_preset['config'] = data.config.model_dump()
# 保存回preferences
prefs['api_presets'] = api_presets
settings.preferences = json.dumps(prefs, ensure_ascii=False)
await db.commit()
logger.info(f"用户 {user.user_id} 更新预设: {preset_id}")
return target_preset
@router.delete("/presets/{preset_id}")
async def delete_preset(
preset_id: str,
user: User = Depends(require_login),
db: AsyncSession = Depends(get_db)
):
"""
删除预设
从preferences字段的JSON中删除指定预设
"""
settings = await get_user_settings(user.user_id, db)
# 解析preferences
try:
prefs = json.loads(settings.preferences or '{}')
except json.JSONDecodeError:
raise HTTPException(status_code=500, detail="配置数据格式错误")
api_presets = prefs.get('api_presets', {'presets': [], 'version': '1.0'})
presets = api_presets.get('presets', [])
# 找到预设
target_preset = next((p for p in presets if p['id'] == preset_id), None)
if not target_preset:
raise HTTPException(status_code=404, detail="预设不存在")
# 检查是否是激活的预设
if target_preset.get('is_active'):
raise HTTPException(status_code=400, detail="无法删除激活中的预设,请先激活其他预设")
# 删除预设
presets = [p for p in presets if p['id'] != preset_id]
# 保存回preferences
api_presets['presets'] = presets
prefs['api_presets'] = api_presets
settings.preferences = json.dumps(prefs, ensure_ascii=False)
await db.commit()
logger.info(f"用户 {user.user_id} 删除预设: {preset_id}")
return {"message": "预设已删除", "preset_id": preset_id}
@router.post("/presets/{preset_id}/activate")
async def activate_preset(
preset_id: str,
user: User = Depends(require_login),
db: AsyncSession = Depends(get_db)
):
"""
激活预设
将预设的配置应用到Settings主字段
"""
settings = await get_user_settings(user.user_id, db)
# 解析preferences
try:
prefs = json.loads(settings.preferences or '{}')
except json.JSONDecodeError:
raise HTTPException(status_code=500, detail="配置数据格式错误")
api_presets = prefs.get('api_presets', {'presets': [], 'version': '1.0'})
presets = api_presets.get('presets', [])
# 找到目标预设
target_preset = next((p for p in presets if p['id'] == preset_id), None)
if not target_preset:
raise HTTPException(status_code=404, detail="预设不存在")
# 应用配置到Settings主字段
config = target_preset['config']
settings.api_provider = config['api_provider']
settings.api_key = config['api_key']
settings.api_base_url = config.get('api_base_url')
settings.llm_model = config['llm_model']
settings.temperature = config['temperature']
settings.max_tokens = config['max_tokens']
# 更新所有预设的is_active状态
for preset in presets:
preset['is_active'] = (preset['id'] == preset_id)
# 保存回preferences
prefs['api_presets'] = api_presets
settings.preferences = json.dumps(prefs, ensure_ascii=False)
await db.commit()
logger.info(f"用户 {user.user_id} 激活预设: {target_preset['name']}")
return {
"message": "预设已激活",
"preset_id": preset_id,
"preset_name": target_preset['name']
}
@router.post("/presets/{preset_id}/test")
async def test_preset(
preset_id: str,
user: User = Depends(require_login),
db: AsyncSession = Depends(get_db)
):
"""
测试预设的API连接
"""
settings = await get_user_settings(user.user_id, db)
# 解析preferences
try:
prefs = json.loads(settings.preferences or '{}')
except json.JSONDecodeError:
raise HTTPException(status_code=500, detail="配置数据格式错误")
api_presets = prefs.get('api_presets', {'presets': [], 'version': '1.0'})
presets = api_presets.get('presets', [])
# 找到预设
target_preset = next((p for p in presets if p['id'] == preset_id), None)
if not target_preset:
raise HTTPException(status_code=404, detail="预设不存在")
# 使用现有的test_api_connection逻辑
config = target_preset['config']
test_request = ApiTestRequest(
api_key=config['api_key'],
api_base_url=config.get('api_base_url', ''),
provider=config['api_provider'],
llm_model=config['llm_model']
)
logger.info(f"用户 {user.user_id} 测试预设: {target_preset['name']}")
return await test_api_connection(test_request)
@router.post("/presets/from-current", response_model=PresetResponse)
async def create_preset_from_current(
name: str,
description: Optional[str] = None,
user: User = Depends(require_login),
db: AsyncSession = Depends(get_db)
):
"""
从当前配置创建新预设
快捷方式将当前激活的配置保存为新预设
"""
settings = await get_user_settings(user.user_id, db)
# 从当前Settings主字段读取配置
current_config = APIKeyPresetConfig(
api_provider=settings.api_provider,
api_key=settings.api_key,
api_base_url=settings.api_base_url,
llm_model=settings.llm_model,
temperature=settings.temperature,
max_tokens=settings.max_tokens
)
# 创建预设
create_request = PresetCreateRequest(
name=name,
description=description,
config=current_config
)
logger.info(f"用户 {user.user_id} 从当前配置创建预设: {name}")
return await create_preset(create_request, user, db)
+65 -1
View File
@@ -5,6 +5,7 @@ from fastapi import APIRouter, HTTPException, Request, Depends
from pydantic import BaseModel
from typing import List, Optional
from app.user_manager import user_manager, User
from app.user_password import password_manager
router = APIRouter(prefix="/users", tags=["用户管理"])
@@ -29,6 +30,11 @@ class SetAdminRequest(BaseModel):
is_admin: bool
class ResetPasswordRequest(BaseModel):
user_id: str
new_password: Optional[str] = None # 如果为空则使用默认密码
@router.get("/current")
async def get_current_user(user: User = Depends(require_login)):
"""获取当前登录用户信息"""
@@ -122,4 +128,62 @@ async def get_user(
if not user:
raise HTTPException(status_code=404, detail="用户不存在")
return user.dict()
return user.dict()
@router.post("/reset-password")
async def reset_user_password(
data: ResetPasswordRequest,
admin_user: User = Depends(require_admin)
):
"""
重置用户密码仅管理员
如果提供了 new_password则设置为指定密码
如果未提供 new_password则重置为默认密码username@666
限制
- 不能重置自己的密码应该使用修改密码功能
"""
# 检查是否尝试重置自己的密码
if data.user_id == admin_user.user_id:
raise HTTPException(
status_code=400,
detail="不能重置自己的密码,请使用修改密码功能"
)
# 检查目标用户是否存在
target_user = await user_manager.get_user(data.user_id)
if not target_user:
raise HTTPException(
status_code=404,
detail="目标用户不存在"
)
# 重置密码
try:
actual_password = await password_manager.set_password(
target_user.user_id,
target_user.username,
data.new_password
)
# 如果使用了默认密码,返回密码供管理员告知用户
message = "密码重置成功"
response_data = {
"message": message,
"user_id": data.user_id,
"username": target_user.username
}
if not data.new_password:
response_data["default_password"] = actual_password
response_data["message"] = f"密码已重置为默认密码: {actual_password}"
return response_data
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"重置密码失败: {str(e)}"
)
File diff suppressed because it is too large Load Diff
+172 -67
View File
@@ -1,5 +1,5 @@
"""写作风格管理 API"""
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, delete
from typing import List
@@ -15,62 +15,90 @@ from ..schemas.writing_style import (
WritingStyleListResponse,
SetDefaultStyleRequest
)
from ..services.prompt_service import WritingStyleManager
from ..logger import get_logger
router = APIRouter(prefix="/writing-styles", tags=["writing-styles"])
logger = get_logger(__name__)
def get_current_user_id(request: Request) -> str:
"""获取当前登录用户ID"""
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
return user_id
@router.get("/presets/list", response_model=List[dict])
async def get_preset_styles():
async def get_preset_styles(db: AsyncSession = Depends(get_db)):
"""
获取所有预设风格列表
获取所有预设风格列表从数据库读取
返回格式数组形式的预设风格列表
[
{"id": "natural", "name": "自然流畅", "description": "...", "prompt_content": "..."},
{"id": "classical", "name": "古典优雅", ...}
{"id": 1, "preset_id": "natural", "name": "自然流畅", "description": "...", "prompt_content": "..."},
{"id": 2, "preset_id": "classical", "name": "古典优雅", ...}
]
"""
presets = WritingStyleManager.get_all_presets()
# 将字典转换为数组,添加 id 字段
# 从数据库获取全局预设风格(user_id 为 NULL
result = await db.execute(
select(WritingStyle)
.where(WritingStyle.user_id.is_(None))
.order_by(WritingStyle.order_index)
)
preset_styles = result.scalars().all()
# 转换为响应格式
return [
{"id": preset_id, **preset_data}
for preset_id, preset_data in presets.items()
{
"id": style.id,
"preset_id": style.preset_id,
"name": style.name,
"description": style.description,
"prompt_content": style.prompt_content,
"style_type": style.style_type,
"order_index": style.order_index
}
for style in preset_styles
]
@router.post("", response_model=WritingStyleResponse, status_code=201)
async def create_writing_style(
style_data: WritingStyleCreate,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
创建新的写作风格
创建新的写作风格用户级别
- **基于预设创建**提供 preset_id系统会自动填充预设内容
- **基于预设创建**提供 preset_id系统会从数据库查询预设内容自动填充
- **完全自定义**不提供 preset_id需要手动填写所有字段
"""
# 验证项目是否存在
result = await db.execute(
select(Project).where(Project.id == style_data.project_id)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 获取当前用户ID
user_id = get_current_user_id(request)
# 如果基于预设创建,获取预设内容
# 如果基于预设创建,从数据库获取预设内容
if style_data.preset_id:
preset = WritingStyleManager.get_preset_style(style_data.preset_id)
result = await db.execute(
select(WritingStyle)
.where(
WritingStyle.user_id.is_(None),
WritingStyle.preset_id == style_data.preset_id
)
)
preset = result.scalar_one_or_none()
if not preset:
raise HTTPException(status_code=400, detail=f"预设风格 '{style_data.preset_id}' 不存在")
# 使用预设内容填充(如果用户未提供)
if not style_data.name:
style_data.name = preset["name"]
style_data.name = preset.name
if not style_data.description:
style_data.description = preset["description"]
style_data.description = preset.description
if not style_data.prompt_content:
style_data.prompt_content = preset["prompt_content"]
style_data.prompt_content = preset.prompt_content
# 验证必填字段
if not style_data.name or not style_data.prompt_content:
@@ -79,16 +107,16 @@ async def create_writing_style(
detail="name 和 prompt_content 是必填字段"
)
# 获取当前最大 order_index
# 获取当前用户的最大 order_index
count_result = await db.execute(
select(func.count(WritingStyle.id))
.where(WritingStyle.project_id == style_data.project_id)
.where(WritingStyle.user_id == user_id)
)
max_order = count_result.scalar_one()
# 创建风格记录
new_style = WritingStyle(
project_id=style_data.project_id,
user_id=user_id,
name=style_data.name,
style_type=style_data.style_type or ("preset" if style_data.preset_id else "custom"),
preset_id=style_data.preset_id,
@@ -104,7 +132,7 @@ async def create_writing_style(
# 返回包含 is_default 字段的字典(新创建的风格默认不是默认风格)
return {
"id": new_style.id,
"project_id": new_style.project_id,
"user_id": new_style.user_id,
"name": new_style.name,
"style_type": new_style.style_type,
"preset_id": new_style.preset_id,
@@ -117,24 +145,85 @@ async def create_writing_style(
}
@router.get("/project/{project_id}", response_model=WritingStyleListResponse)
async def get_project_styles(
project_id: str,
@router.get("/user", response_model=WritingStyleListResponse)
async def get_user_styles(
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
获取项目的所有可用写作风格
获取用户的所有可用写作风格
返回全局预设风格 + 项目的自定义风格
返回全局预设风格 + 用户的自定义风格
order_index 排序
"""
# 获取当前用户ID
user_id = get_current_user_id(request)
# 获取全局预设风格(user_id 为 NULL
result = await db.execute(
select(WritingStyle)
.where(WritingStyle.user_id.is_(None))
.order_by(WritingStyle.order_index)
)
preset_styles = list(result.scalars().all())
# 获取用户自定义风格
result = await db.execute(
select(WritingStyle)
.where(WritingStyle.user_id == user_id)
.order_by(WritingStyle.order_index)
)
custom_styles = list(result.scalars().all())
# 合并:预设风格 + 自定义风格
all_styles = preset_styles + custom_styles
# 转换为响应格式
styles_with_default = []
for style in all_styles:
style_dict = {
"id": style.id,
"user_id": style.user_id,
"name": style.name,
"style_type": style.style_type,
"preset_id": style.preset_id,
"description": style.description,
"prompt_content": style.prompt_content,
"order_index": style.order_index,
"created_at": style.created_at,
"updated_at": style.updated_at,
"is_default": False # 用户级别不再需要默认风格标记
}
styles_with_default.append(style_dict)
return {"styles": styles_with_default, "total": len(styles_with_default)}
@router.get("/project/{project_id}", response_model=WritingStyleListResponse)
async def get_project_styles(
project_id: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
获取项目可用的所有写作风格保留用于向后兼容
返回全局预设风格 + 该用户的自定义风格
order_index 排序并标记哪个是当前项目的默认风格
"""
# 验证项目是否存在
# 获取当前用户ID
user_id = get_current_user_id(request)
# 验证项目访问权限
result = await db.execute(
select(Project).where(Project.id == project_id)
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
# 获取该项目的默认风格ID
result = await db.execute(
@@ -143,18 +232,18 @@ async def get_project_styles(
)
default_style_id = result.scalar_one_or_none()
# 获取全局预设风格(project_id 为 NULL
# 获取全局预设风格(user_id 为 NULL
result = await db.execute(
select(WritingStyle)
.where(WritingStyle.project_id.is_(None))
.where(WritingStyle.user_id.is_(None))
.order_by(WritingStyle.order_index)
)
preset_styles = list(result.scalars().all())
# 获取项目自定义风格
# 获取用户自定义风格
result = await db.execute(
select(WritingStyle)
.where(WritingStyle.project_id == project_id)
.where(WritingStyle.user_id == user_id)
.order_by(WritingStyle.order_index)
)
custom_styles = list(result.scalars().all())
@@ -167,7 +256,7 @@ async def get_project_styles(
for style in all_styles:
style_dict = {
"id": style.id,
"project_id": style.project_id,
"user_id": style.user_id,
"name": style.name,
"style_type": style.style_type,
"preset_id": style.preset_id,
@@ -196,16 +285,16 @@ async def get_writing_style(
if not style:
raise HTTPException(status_code=404, detail="写作风格不存在")
# 检查是否有项目将其设置为默认风格
# 检查是否有项目将其设置为默认风格(一个风格可能被多个项目使用,使用 first() 避免 MultipleResultsFound
result = await db.execute(
select(ProjectDefaultStyle).where(ProjectDefaultStyle.style_id == style_id)
)
is_default = result.scalar_one_or_none() is not None
is_default = result.scalars().first() is not None
# 返回包含 is_default 字段的字典
return {
"id": style.id,
"project_id": style.project_id,
"user_id": style.user_id,
"name": style.name,
"style_type": style.style_type,
"preset_id": style.preset_id,
@@ -222,6 +311,7 @@ async def get_writing_style(
async def update_writing_style(
style_id: int,
style_data: WritingStyleUpdate,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
@@ -230,6 +320,9 @@ async def update_writing_style(
- 只能修改自定义风格
- 不能修改全局预设风格
"""
# 获取当前用户ID
user_id = get_current_user_id(request)
result = await db.execute(
select(WritingStyle).where(WritingStyle.id == style_id)
)
@@ -238,9 +331,13 @@ async def update_writing_style(
raise HTTPException(status_code=404, detail="写作风格不存在")
# 检查是否为全局预设风格(不允许修改)
if style.project_id is None:
if style.user_id is None:
raise HTTPException(status_code=403, detail="不能修改全局预设风格,只能修改自定义风格")
# 验证用户权限(只能修改自己的风格)
if style.user_id != user_id:
raise HTTPException(status_code=403, detail="无权修改其他用户的风格")
# 更新字段
update_data = style_data.model_dump(exclude_unset=True)
@@ -254,16 +351,16 @@ async def update_writing_style(
await db.commit()
await db.refresh(style)
# 检查是否有项目将其设置为默认风格
# 检查是否有项目将其设置为默认风格(一个风格可能被多个项目使用,使用 first() 避免 MultipleResultsFound
result = await db.execute(
select(ProjectDefaultStyle).where(ProjectDefaultStyle.style_id == style_id)
)
is_default = result.scalar_one_or_none() is not None
is_default = result.scalars().first() is not None
# 返回包含 is_default 字段的字典
return {
"id": style.id,
"project_id": style.project_id,
"user_id": style.user_id,
"name": style.name,
"style_type": style.style_type,
"preset_id": style.preset_id,
@@ -279,6 +376,7 @@ async def update_writing_style(
@router.delete("/{style_id}", status_code=204)
async def delete_writing_style(
style_id: int,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
@@ -289,6 +387,9 @@ async def delete_writing_style(
- 不能删除默认风格必须先设置其他风格为默认
- 删除后无法恢复
"""
# 获取当前用户ID
user_id = get_current_user_id(request)
result = await db.execute(
select(WritingStyle).where(WritingStyle.id == style_id)
)
@@ -297,14 +398,18 @@ async def delete_writing_style(
raise HTTPException(status_code=404, detail="写作风格不存在")
# 检查是否为全局预设风格(不允许删除)
if style.project_id is None:
if style.user_id is None:
raise HTTPException(status_code=403, detail="不能删除全局预设风格,只能删除自定义风格")
# 检查是否有项目将其设置为默认风格
# 验证用户权限(只能删除自己的风格
if style.user_id != user_id:
raise HTTPException(status_code=403, detail="无权删除其他用户的风格")
# 检查是否有项目将其设置为默认风格(一个风格可能被多个项目使用,使用 first() 避免 MultipleResultsFound
result = await db.execute(
select(ProjectDefaultStyle).where(ProjectDefaultStyle.style_id == style_id)
)
default_relation = result.scalar_one_or_none()
default_relation = result.scalars().first()
if default_relation:
raise HTTPException(
status_code=400,
@@ -321,6 +426,7 @@ async def delete_writing_style(
async def set_default_style(
style_id: int,
request_data: SetDefaultStyleRequest,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
@@ -335,13 +441,19 @@ async def set_default_style(
"""
project_id = request_data.project_id
# 验证项目是否存在
# 获取当前用户ID
user_id = get_current_user_id(request)
# 验证项目访问权限
result = await db.execute(
select(Project).where(Project.id == project_id)
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
# 验证风格是否存在
result = await db.execute(
@@ -351,9 +463,9 @@ async def set_default_style(
if not style:
raise HTTPException(status_code=404, detail="写作风格不存在")
# 验证风格是否属于该项目(自定义风格)或是全局预设风格
if style.project_id is not None and style.project_id != project_id:
raise HTTPException(status_code=403, detail="无权操作其他项目的风格")
# 验证风格是否属于该用户(自定义风格)或是全局预设风格
if style.user_id is not None and style.user_id != user_id:
raise HTTPException(status_code=403, detail="无权操作其他用户的风格")
# 使用 UPSERT 逻辑:先删除该项目的旧默认风格记录,再插入新的
await db.execute(
@@ -379,6 +491,7 @@ async def set_default_style(
@router.post("/project/{project_id}/init-defaults", response_model=WritingStyleListResponse)
async def initialize_default_styles(
project_id: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
@@ -387,13 +500,5 @@ async def initialize_default_styles(
新架构下预设风格是全局的不需要为每个项目单独初始化
该接口保留用于兼容性直接返回项目可用的所有风格
"""
# 验证项目是否存在
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 直接返回项目可用的所有风格(全局预设 + 项目自定义)
return await get_project_styles(project_id, db)
# 直接返回项目可用的所有风格(全局预设 + 用户自定义)
return await get_project_styles(project_id, request, db)
+33 -8
View File
@@ -3,6 +3,7 @@ from pydantic_settings import BaseSettings
from typing import Optional
from pathlib import Path
import logging
import os
# 获取项目根目录(从backend/app/config.py向上两级)
PROJECT_ROOT = Path(__file__).parent.parent
@@ -12,13 +13,11 @@ DATA_DIR.mkdir(exist_ok=True)
# 配置模块使用标准logging(在logger.py初始化之前)
config_logger = logging.getLogger(__name__)
# 数据库文件路径(绝对路径)
DB_FILE = DATA_DIR / "ai_story.db"
# 数据库配置:PostgreSQL
# 从环境变量获取数据库URL
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql+asyncpg://mumuai:password@localhost:5432/mumuai_novel")
# 生成数据库URL(在类外部生成,确保使用绝对路径)
# 将Windows反斜杠转换为正斜杠,SQLite URL格式要求
DATABASE_URL = f"sqlite+aiosqlite:///{str(DB_FILE.absolute()).replace(chr(92), '/')}"
config_logger.debug(f"数据库文件路径: {DB_FILE}")
config_logger.debug(f"数据库类型: PostgreSQL")
config_logger.debug(f"数据库URL: {DATABASE_URL}")
class Settings(BaseSettings):
@@ -41,9 +40,31 @@ class Settings(BaseSettings):
# CORS配置
cors_origins: list[str] = ["http://localhost:8000", "http://127.0.0.1:8000"]
# 数据库配置 - 使用预先计算好的绝对路径URL
# 数据库配置 - PostgreSQL
database_url: str = DATABASE_URL
# PostgreSQL连接池配置(优化后支持150-200并发用户)
database_pool_size: int = 50 # 核心连接池大小(优化:从30提升到50)
database_max_overflow: int = 30 # 最大溢出连接数(优化:从20提升到30)
database_pool_timeout: int = 90 # 连接池超时秒数(优化:从60提升到90)
database_pool_recycle: int = 1800 # 连接回收时间秒数(30分钟,防止长时间连接失效)
database_pool_pre_ping: bool = True # 连接前ping检测,确保连接有效
database_pool_use_lifo: bool = True # 使用LIFO策略提高连接复用率
# 连接池高级配置
database_echo_pool: bool = False # 是否记录连接池日志(调试用)
database_pool_reset_on_return: str = "rollback" # 连接归还时的重置策略:rollback/commit/none
database_max_identifier_length: int = 128 # PostgreSQL标识符最大长度
# 会话监控配置
database_session_max_active: int = 50 # 活跃会话警告阈值(从100降低到50)
database_session_leak_threshold: int = 100 # 会话泄漏严重告警阈值
# 数据库监控配置
database_enable_slow_query_log: bool = True # 启用慢查询日志
database_slow_query_threshold: float = 1.0 # 慢查询阈值(秒)
database_enable_metrics: bool = True # 启用性能指标收集
# AI服务配置
openai_api_key: Optional[str] = None
openai_base_url: Optional[str] = None
@@ -54,7 +75,10 @@ class Settings(BaseSettings):
default_ai_provider: str = "openai"
default_model: str = "gpt-4"
default_temperature: float = 0.7
default_max_tokens: int = 2000
default_max_tokens: int = 32000
# MCP配置
mcp_max_rounds: int = 3 # MCP工具调用最大轮数(全局统一控制)
# LinuxDO OAuth2 配置
LINUXDO_CLIENT_ID: Optional[str] = None
@@ -85,6 +109,7 @@ class Settings(BaseSettings):
class Config:
env_file = ".env"
case_sensitive = False
extra = "ignore" # 忽略未定义的环境变量,避免验证错误
# 创建全局配置实例
+286 -179
View File
@@ -1,11 +1,10 @@
"""数据库连接和会话管理 - 支持多用户数据隔离"""
"""数据库连接和会话管理 - PostgreSQL 多用户数据隔离"""
import asyncio
from typing import Dict, Any
from datetime import datetime
from sqlalchemy import select, text
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import declarative_base
from sqlalchemy.pool import StaticPool
from fastapi import Request, HTTPException
from app.config import settings
from app.logger import get_logger
@@ -21,7 +20,8 @@ from app.models import (
Project, Outline, Character, Chapter, GenerationHistory,
Settings, WritingStyle, ProjectDefaultStyle,
RelationshipType, CharacterRelationship, Organization, OrganizationMember,
StoryMemory, PlotAnalysis, AnalysisTask
StoryMemory, PlotAnalysis, AnalysisTask, BatchGenerationTask,
RegenerationTask, Career, CharacterCareer, User, MCPPlugin, PromptTemplate
)
# 引擎缓存:每个用户一个引擎
@@ -45,51 +45,101 @@ _session_stats = {
async def get_engine(user_id: str):
"""获取或创建用户专属的数据库引擎(线程安全)
PostgreSQL: 所有用户共享一个数据库通过user_id字段隔离数据
Args:
user_id: 用户ID
Returns:
用户专属的异步引擎
"""
if user_id in _engine_cache:
return _engine_cache[user_id]
# PostgreSQL模式:所有用户共享同一个引擎
cache_key = "shared_postgres"
if cache_key in _engine_cache:
return _engine_cache[cache_key]
async with _cache_lock:
if user_id not in _engine_locks:
_engine_locks[user_id] = asyncio.Lock()
user_lock = _engine_locks[user_id]
async with user_lock:
if user_id not in _engine_cache:
db_url = f"sqlite+aiosqlite:///data/ai_story_user_{user_id}.db"
engine = create_async_engine(
db_url,
echo=False,
future=True,
poolclass=StaticPool,
pool_pre_ping=True,
pool_recycle=3600,
connect_args={
"timeout": 30,
"check_same_thread": False
}
)
if cache_key not in _engine_cache:
# 检测数据库类型
is_sqlite = 'sqlite' in settings.database_url.lower()
try:
async with engine.begin() as conn:
await conn.execute(text("PRAGMA journal_mode=WAL"))
await conn.execute(text("PRAGMA synchronous=NORMAL"))
await conn.execute(text("PRAGMA cache_size=-64000"))
await conn.execute(text("PRAGMA temp_store=MEMORY"))
await conn.execute(text("PRAGMA busy_timeout=5000"))
# 基础引擎参数
engine_args = {
"echo": settings.database_echo_pool,
"echo_pool": settings.database_echo_pool,
"future": True,
}
if is_sqlite:
# SQLite 配置(使用 NullPool,不支持连接池参数)
engine_args["connect_args"] = {
"check_same_thread": False,
"timeout": 30.0, # 等待锁释放的超时时间(秒)
}
# 启用连接前检测以支持更好的并发
engine_args["pool_pre_ping"] = True
logger.info("📊 使用 SQLite 数据库(NullPool,超时30秒,WAL模式)")
else:
# PostgreSQL 配置(完整连接池支持)
connect_args = {
"server_settings": {
"application_name": settings.app_name,
"jit": "off",
"search_path": "public",
},
"command_timeout": 60,
"statement_cache_size": 500,
}
engine_args.update({
"pool_size": settings.database_pool_size,
"max_overflow": settings.database_max_overflow,
"pool_timeout": settings.database_pool_timeout,
"pool_pre_ping": settings.database_pool_pre_ping,
"pool_recycle": settings.database_pool_recycle,
"pool_use_lifo": settings.database_pool_use_lifo,
"pool_reset_on_return": settings.database_pool_reset_on_return,
"max_identifier_length": settings.database_max_identifier_length,
"connect_args": connect_args
})
total_connections = settings.database_pool_size + settings.database_max_overflow
estimated_concurrent_users = total_connections * 2
logger.info(
f"📊 PostgreSQL 连接池配置:\n"
f" ├─ 核心连接: {settings.database_pool_size}\n"
f" ├─ 溢出连接: {settings.database_max_overflow}\n"
f" ├─ 总连接数: {total_connections}\n"
f" ├─ 获取超时: {settings.database_pool_timeout}\n"
f" ├─ 连接回收: {settings.database_pool_recycle}\n"
f" └─ 预估并发: {estimated_concurrent_users}+用户"
)
engine = create_async_engine(settings.database_url, **engine_args)
_engine_cache[cache_key] = engine
# 如果是 SQLite,启用 WAL 模式以支持读写并发
if is_sqlite:
try:
from sqlalchemy import event
from sqlalchemy.pool import NullPool
logger.info(f"✅ 用户 {user_id} 的数据库已优化(WAL模式 + 64MB缓存)")
except Exception as e:
logger.warning(f"⚠️ 用户 {user_id} 数据库优化失败: {str(e)}")
_engine_cache[user_id] = engine
logger.info(f"为用户 {user_id} 创建数据库引擎")
@event.listens_for(engine.sync_engine, "connect")
def set_sqlite_pragma(dbapi_conn, connection_record):
cursor = dbapi_conn.cursor()
cursor.execute("PRAGMA journal_mode=WAL")
cursor.execute("PRAGMA synchronous=NORMAL")
cursor.execute("PRAGMA cache_size=-64000") # 64MB 缓存
cursor.execute("PRAGMA busy_timeout=30000") # 30秒超时
cursor.close()
logger.info("✅ SQLite WAL 模式已启用(支持读写并发)")
except Exception as e:
logger.warning(f"⚠️ 启用 WAL 模式失败: {e},使用默认配置")
return _engine_cache[user_id]
return _engine_cache[cache_key]
async def get_db(request: Request):
@@ -117,7 +167,7 @@ async def get_db(request: Request):
_session_stats["created"] += 1
_session_stats["active"] += 1
logger.debug(f"📊 会话创建 [User:{user_id}][ID:{session_id}] - 活跃:{_session_stats['active']}, 总创建:{_session_stats['created']}, 总关闭:{_session_stats['closed']}")
# logger.debug(f"📊 会话创建 [User:{user_id}][ID:{session_id}] - 活跃:{_session_stats['active']}, 总创建:{_session_stats['created']}, 总关闭:{_session_stats['closed']}")
try:
yield session
@@ -157,8 +207,11 @@ async def get_db(request: Request):
logger.debug(f"📊 会话关闭 [User:{user_id}][ID:{session_id}] - 活跃:{_session_stats['active']}, 总创建:{_session_stats['created']}, 总关闭:{_session_stats['closed']}, 错误:{_session_stats['errors']}")
if _session_stats["active"] > 100:
logger.warning(f"🚨 活跃会话数过多: {_session_stats['active']},可能存在连接泄漏!")
# 使用优化后的会话监控阈值
if _session_stats["active"] > settings.database_session_leak_threshold:
logger.error(f"🚨 严重告警:活跃会话数 {_session_stats['active']} 超过泄漏阈值 {settings.database_session_leak_threshold}")
elif _session_stats["active"] > settings.database_session_max_active:
logger.warning(f"⚠️ 警告:活跃会话数 {_session_stats['active']} 超过警告阈值 {settings.database_session_max_active},可能存在连接泄漏!")
elif _session_stats["active"] < 0:
logger.error(f"🚨 活跃会话数异常: {_session_stats['active']},统计可能不准确!")
@@ -170,147 +223,25 @@ async def get_db(request: Request):
except:
pass
async def _init_relationship_types(user_id: str):
"""为指定用户初始化预置的关系类型数据
async def init_db(user_id: str = None):
"""
初始化数据库已弃用
此函数已弃用仅保留用于向后兼容
新的最佳实践:
- 表结构管理: 使用 'alembic upgrade head'
- 用户配置: Settings 在首次访问时自动创建延迟初始化
Args:
user_id: 用户ID
user_id: 用户ID (已不再使用)
"""
from app.models.relationship import RelationshipType
relationship_types = [
{"name": "父亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👨"},
{"name": "母亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👩"},
{"name": "兄弟", "category": "family", "reverse_name": "兄弟", "intimacy_range": "high", "icon": "👬"},
{"name": "姐妹", "category": "family", "reverse_name": "姐妹", "intimacy_range": "high", "icon": "👭"},
{"name": "子女", "category": "family", "reverse_name": "父母", "intimacy_range": "high", "icon": "👶"},
{"name": "配偶", "category": "family", "reverse_name": "配偶", "intimacy_range": "high", "icon": "💑"},
{"name": "恋人", "category": "family", "reverse_name": "恋人", "intimacy_range": "high", "icon": "💕"},
{"name": "师父", "category": "social", "reverse_name": "徒弟", "intimacy_range": "high", "icon": "🎓"},
{"name": "徒弟", "category": "social", "reverse_name": "师父", "intimacy_range": "high", "icon": "📚"},
{"name": "朋友", "category": "social", "reverse_name": "朋友", "intimacy_range": "medium", "icon": "🤝"},
{"name": "同学", "category": "social", "reverse_name": "同学", "intimacy_range": "medium", "icon": "🎒"},
{"name": "邻居", "category": "social", "reverse_name": "邻居", "intimacy_range": "low", "icon": "🏘️"},
{"name": "知己", "category": "social", "reverse_name": "知己", "intimacy_range": "high", "icon": "💙"},
{"name": "上司", "category": "professional", "reverse_name": "下属", "intimacy_range": "low", "icon": "👔"},
{"name": "下属", "category": "professional", "reverse_name": "上司", "intimacy_range": "low", "icon": "💼"},
{"name": "同事", "category": "professional", "reverse_name": "同事", "intimacy_range": "medium", "icon": "🤵"},
{"name": "合作伙伴", "category": "professional", "reverse_name": "合作伙伴", "intimacy_range": "medium", "icon": "🤜🤛"},
{"name": "敌人", "category": "hostile", "reverse_name": "敌人", "intimacy_range": "low", "icon": "⚔️"},
{"name": "仇人", "category": "hostile", "reverse_name": "仇人", "intimacy_range": "low", "icon": "💢"},
{"name": "竞争对手", "category": "hostile", "reverse_name": "竞争对手", "intimacy_range": "low", "icon": "🎯"},
{"name": "宿敌", "category": "hostile", "reverse_name": "宿敌", "intimacy_range": "low", "icon": ""},
]
try:
engine = await get_engine(user_id)
AsyncSessionLocal = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False
)
async with AsyncSessionLocal() as session:
result = await session.execute(select(RelationshipType))
existing = result.scalars().first()
if existing:
logger.info(f"用户 {user_id} 的关系类型数据已存在,跳过初始化")
return
logger.info(f"开始为用户 {user_id} 插入关系类型数据...")
for rt_data in relationship_types:
relationship_type = RelationshipType(**rt_data)
session.add(relationship_type)
await session.commit()
logger.info(f"成功为用户 {user_id} 插入 {len(relationship_types)} 条关系类型数据")
except Exception as e:
logger.error(f"用户 {user_id} 初始化关系类型数据失败: {str(e)}", exc_info=True)
raise
async def _init_global_writing_styles(user_id: str):
"""为指定用户初始化全局预设写作风格
全局预设风格的 project_id NULL所有用户共享
只在第一次创建数据库时插入一次
Args:
user_id: 用户ID
"""
from app.models.writing_style import WritingStyle
from app.services.prompt_service import WritingStyleManager
try:
engine = await get_engine(user_id)
AsyncSessionLocal = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False
)
async with AsyncSessionLocal() as session:
# 检查是否已存在全局预设风格
result = await session.execute(
select(WritingStyle).where(WritingStyle.project_id.is_(None))
)
existing = result.scalars().first()
if existing:
logger.info(f"用户 {user_id} 的全局预设风格已存在,跳过初始化")
return
logger.info(f"开始为用户 {user_id} 插入全局预设写作风格...")
# 获取所有预设风格配置
presets = WritingStyleManager.get_all_presets()
for index, (preset_id, preset_data) in enumerate(presets.items(), start=1):
style = WritingStyle(
project_id=None, # NULL 表示全局预设
name=preset_data["name"],
style_type="preset",
preset_id=preset_id,
description=preset_data["description"],
prompt_content=preset_data["prompt_content"],
order_index=index
)
session.add(style)
await session.commit()
logger.info(f"成功为用户 {user_id} 插入 {len(presets)} 个全局预设写作风格")
except Exception as e:
logger.error(f"用户 {user_id} 初始化全局预设写作风格失败: {str(e)}", exc_info=True)
raise
async def init_db(user_id: str):
"""初始化指定用户的数据库,创建所有表并插入预置数据
Args:
user_id: 用户ID
"""
try:
logger.info(f"开始初始化用户 {user_id} 的数据库...")
engine = await get_engine(user_id)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
await _init_relationship_types(user_id)
await _init_global_writing_styles(user_id)
logger.info(f"用户 {user_id} 的数据库初始化成功")
except Exception as e:
logger.error(f"用户 {user_id} 的数据库初始化失败: {str(e)}", exc_info=True)
raise
logger.warning(
"⚠️ init_db() 已弃用且无实际作用!\n"
" - 表结构: 由 Alembic 管理\n"
" - 用户配置: Settings API 自动创建\n"
" 建议移除此调用"
)
async def close_db():
@@ -324,4 +255,180 @@ async def close_db():
logger.info("所有数据库连接已关闭")
except Exception as e:
logger.error(f"关闭数据库连接失败: {str(e)}", exc_info=True)
raise
raise
async def get_database_stats():
"""获取数据库连接和会话统计信息
Returns:
dict: 包含数据库统计信息的字典
"""
from app.config import settings
# 获取连接池详细状态
pool_stats = {}
cache_key = "shared_postgres"
if cache_key in _engine_cache:
engine = _engine_cache[cache_key]
try:
pool = engine.pool
pool_stats = {
"size": pool.size(), # 当前连接池大小
"checked_in": pool.checkedin(), # 可用连接数
"checked_out": pool.checkedout(), # 正在使用的连接数
"overflow": pool.overflow(), # 溢出连接数
"usage_percent": (pool.checkedout() / (settings.database_pool_size + settings.database_max_overflow)) * 100,
}
except Exception as e:
logger.warning(f"获取连接池状态失败: {e}")
pool_stats = {"error": str(e)}
stats = {
"session_stats": {
"created": _session_stats["created"],
"closed": _session_stats["closed"],
"active": _session_stats["active"],
"errors": _session_stats["errors"],
"generator_exits": _session_stats["generator_exits"],
"last_check": _session_stats["last_check"],
},
"pool_stats": pool_stats, # 新增:连接池实时状态
"engine_cache": {
"total_engines": len(_engine_cache),
"engine_keys": list(_engine_cache.keys()),
},
"config": {
"database_type": "PostgreSQL",
"pool_size": settings.database_pool_size,
"max_overflow": settings.database_max_overflow,
"total_connections": settings.database_pool_size + settings.database_max_overflow,
"pool_timeout": settings.database_pool_timeout,
"pool_recycle": settings.database_pool_recycle,
"session_max_active_threshold": settings.database_session_max_active,
"session_leak_threshold": settings.database_session_leak_threshold,
},
"health": {
"status": "healthy",
"warnings": [],
"errors": [],
}
}
# 健康检查
if _session_stats["active"] > settings.database_session_leak_threshold:
stats["health"]["status"] = "critical"
stats["health"]["errors"].append(
f"活跃会话数 {_session_stats['active']} 超过泄漏阈值 {settings.database_session_leak_threshold}"
)
elif _session_stats["active"] > settings.database_session_max_active:
stats["health"]["status"] = "warning"
stats["health"]["warnings"].append(
f"活跃会话数 {_session_stats['active']} 超过警告阈值 {settings.database_session_max_active}"
)
if _session_stats["active"] < 0:
stats["health"]["status"] = "error"
stats["health"]["errors"].append(f"活跃会话数异常: {_session_stats['active']}")
# 连接池使用率检查
if pool_stats and "usage_percent" in pool_stats:
usage = pool_stats["usage_percent"]
if usage > 90:
stats["health"]["status"] = "warning"
stats["health"]["warnings"].append(f"连接池使用率过高: {usage:.1f}%")
elif usage > 95:
stats["health"]["status"] = "critical"
stats["health"]["errors"].append(f"连接池几乎耗尽: {usage:.1f}%")
error_rate = (_session_stats["errors"] / max(_session_stats["created"], 1)) * 100
if error_rate > 5:
if stats["health"]["status"] == "healthy":
stats["health"]["status"] = "warning"
stats["health"]["warnings"].append(f"会话错误率过高: {error_rate:.2f}%")
stats["health"]["error_rate"] = f"{error_rate:.2f}%"
return stats
async def check_database_health(user_id: str = None) -> dict:
"""检查数据库连接健康状态
Args:
user_id: 可选的用户ID如果提供则检查特定用户的数据库
Returns:
dict: 健康检查结果
"""
result = {
"healthy": True,
"checks": {},
"timestamp": datetime.now().isoformat()
}
try:
# 检查引擎是否存在
cache_key = "shared_postgres"
if user_id:
engine = await get_engine(user_id)
else:
if cache_key not in _engine_cache:
result["checks"]["engine"] = {"status": "not_initialized", "healthy": True}
return result
engine = _engine_cache[cache_key]
# 测试数据库连接
AsyncSessionLocal = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False
)
async with AsyncSessionLocal() as session:
# 执行简单查询测试连接
await session.execute(text("SELECT 1"))
result["checks"]["connection"] = {"status": "ok", "healthy": True}
# 检查连接池状态(仅PostgreSQL)
if hasattr(engine.pool, 'size'):
pool_status = {
"size": engine.pool.size(),
"checked_in": engine.pool.checkedin(),
"checked_out": engine.pool.checkedout(),
"overflow": engine.pool.overflow(),
"healthy": True
}
# 连接池健康检查
if engine.pool.overflow() >= settings.database_max_overflow:
pool_status["healthy"] = False
pool_status["warning"] = "连接池溢出已满"
result["healthy"] = False
result["checks"]["pool"] = pool_status
except Exception as e:
result["healthy"] = False
result["checks"]["error"] = {
"status": "error",
"message": str(e),
"healthy": False
}
logger.error(f"数据库健康检查失败: {str(e)}", exc_info=True)
return result
async def reset_session_stats():
"""重置会话统计信息(用于测试或维护)"""
global _session_stats
_session_stats = {
"created": 0,
"closed": 0,
"active": 0,
"errors": 0,
"generator_exits": 0,
"last_check": datetime.now().isoformat()
}
logger.info("✅ 会话统计信息已重置")
return _session_stats
+5 -1
View File
@@ -130,11 +130,15 @@ def _configure_third_party_loggers():
logging.getLogger('sqlalchemy.dialects').setLevel(logging.WARNING)
logging.getLogger('sqlalchemy.orm').setLevel(logging.WARNING)
# aiosqlite - 异步SQLite,禁用DEBUG日志
logging.getLogger('aiosqlite').setLevel(logging.WARNING)
# Watchfiles - 开发时的文件监控,降低级别
logging.getLogger('watchfiles').setLevel(logging.WARNING)
# httpx - HTTP客户端
# httpx/httpcore - HTTP客户端,禁用DEBUG日志
logging.getLogger('httpx').setLevel(logging.WARNING)
logging.getLogger('httpcore').setLevel(logging.WARNING)
# openai/anthropic - AI客户端库
logging.getLogger('openai').setLevel(logging.WARNING)
+24 -2
View File
@@ -12,6 +12,7 @@ from app.database import close_db, _session_stats
from app.logger import setup_logging, get_logger
from app.middleware import RequestIDMiddleware
from app.middleware.auth_middleware import AuthMiddleware
from app.mcp import mcp_client, register_status_sync
setup_logging(
level=config_settings.log_level,
@@ -26,10 +27,23 @@ logger = get_logger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期管理"""
logger.info("应用启动,等待用户登录...")
# 注册MCP状态同步服务
register_status_sync()
logger.info("应用启动完成")
yield
# 清理MCP插件
await mcp_client.cleanup()
# 清理HTTP客户端池
from app.services.ai_service import cleanup_http_clients
await cleanup_http_clients()
# 关闭数据库连接
await close_db()
logger.info("应用已关闭")
@@ -114,22 +128,30 @@ async def db_session_stats():
from app.api import (
projects, outlines, characters, chapters,
wizard_stream, relationships, organizations,
auth, users, settings, writing_styles, memories
auth, users, settings, writing_styles, memories,
mcp_plugins, admin, inspiration, prompt_templates,
changelog, careers
)
app.include_router(auth.router, prefix="/api")
app.include_router(users.router, prefix="/api")
app.include_router(settings.router, prefix="/api")
app.include_router(admin.router, prefix="/api")
app.include_router(projects.router, prefix="/api")
app.include_router(wizard_stream.router, prefix="/api")
app.include_router(inspiration.router, prefix="/api")
app.include_router(outlines.router, prefix="/api")
app.include_router(characters.router, prefix="/api")
app.include_router(careers.router, prefix="/api") # 职业管理API
app.include_router(chapters.router, prefix="/api")
app.include_router(relationships.router, prefix="/api")
app.include_router(organizations.router, prefix="/api")
app.include_router(writing_styles.router, prefix="/api")
app.include_router(memories.router) # 记忆管理API (已包含/api前缀)
app.include_router(mcp_plugins.router, prefix="/api") # MCP插件管理API
app.include_router(prompt_templates.router, prefix="/api") # 提示词模板管理API
app.include_router(changelog.router, prefix="/api") # 更新日志API
static_dir = Path(__file__).parent.parent / "static"
if static_dir.exists():
+36
View File
@@ -0,0 +1,36 @@
"""MCP模块 - 统一的MCP客户端管理
本模块提供MCPModel Context Protocol客户端的统一管理接口
推荐使用方式
from app.mcp import mcp_client, MCPPluginConfig
# 注册插件
await mcp_client.register(MCPPluginConfig(
user_id="user123",
plugin_name="exa-search",
url="http://localhost:8000/mcp"
))
# 获取工具
tools = await mcp_client.get_tools("user123", "exa-search")
# 调用工具
result = await mcp_client.call_tool("user123", "exa-search", "web_search", {"query": "..."})
# 注册状态变更回调
from app.mcp.status_sync import register_status_sync
register_status_sync()
"""
from .facade import mcp_client, MCPClientFacade, MCPPluginConfig, MCPError, PluginStatus
from .status_sync import register_status_sync
__all__ = [
"mcp_client",
"MCPClientFacade",
"MCPPluginConfig",
"MCPError",
"PluginStatus",
"register_status_sync",
]
+42
View File
@@ -0,0 +1,42 @@
"""MCP模块配置常量"""
from dataclasses import dataclass
@dataclass(frozen=True)
class MCPConfig:
"""MCP模块配置常量(不可变)"""
# 连接池配置
MAX_CLIENTS: int = 1000 # 最大客户端数量
CLIENT_TTL_SECONDS: int = 3600 # 客户端过期时间(1小时)
IDLE_TIMEOUT_SECONDS: int = 1800 # 空闲超时(30分钟)
# 健康检查配置
HEALTH_CHECK_INTERVAL_SECONDS: int = 60 # 健康检查间隔
ERROR_RATE_CRITICAL: float = 0.7 # 严重错误率阈值
ERROR_RATE_WARNING: float = 0.4 # 警告错误率阈值
MIN_REQUESTS_FOR_HEALTH_CHECK: int = 10 # 进行健康检查的最小请求数
# 清理任务配置
CLEANUP_INTERVAL_SECONDS: int = 300 # 清理任务间隔(5分钟)
# 缓存配置
TOOL_CACHE_TTL_MINUTES: int = 10 # 工具定义缓存TTL
# 重试配置
MAX_RETRIES: int = 3 # 最大重试次数
BASE_RETRY_DELAY_SECONDS: float = 1.0 # 基础重试延迟
MAX_RETRY_DELAY_SECONDS: float = 10.0 # 最大重试延迟
# 超时配置
DEFAULT_TIMEOUT_SECONDS: float = 60.0 # 默认超时时间
TOOL_CALL_TIMEOUT_SECONDS: float = 60.0 # 工具调用超时时间
# 日志配置
LOG_TOOL_ARGUMENTS: bool = True # 是否记录工具参数
LOG_TOOL_RESULTS: bool = False # 是否记录工具结果(可能很大)
# 全局配置实例
mcp_config = MCPConfig()
File diff suppressed because it is too large Load Diff
+50
View File
@@ -0,0 +1,50 @@
"""MCP插件状态同步服务
将内存中的会话状态变更同步到数据库确保状态一致性
"""
from typing import Dict, Any
from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from app.models.mcp_plugin import MCPPlugin
from app.logger import get_logger
logger = get_logger(__name__)
async def sync_status_to_db(event: Dict[str, Any]):
"""
状态变更回调 - 同步到数据库
"""
user_id = event["user_id"]
plugin_name = event["plugin_name"]
new_status = event["new_status"]
reason = event.get("reason", "")
try:
from app.database import get_engine
engine = await get_engine(user_id)
AsyncSessionLocal = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
async with AsyncSessionLocal() as db:
stmt = (
update(MCPPlugin)
.where(MCPPlugin.user_id == user_id, MCPPlugin.plugin_name == plugin_name)
.values(status=new_status, last_error=reason if new_status == "error" else None)
)
await db.execute(stmt)
await db.commit()
logger.debug(f"✅ 状态已同步到数据库: {plugin_name} -> {new_status}")
except Exception as e:
logger.error(f"❌ 状态同步失败: {plugin_name}, 错误: {e}")
def register_status_sync():
"""注册状态同步回调到MCP客户端"""
from app.mcp import mcp_client
mcp_client.register_status_callback(sync_status_to_db)
logger.info("✅ MCP状态同步服务已注册")
+16 -4
View File
@@ -1,9 +1,12 @@
"""
认证中间件 - Cookie 中提取用户信息并注入到 request.state
"""
from fastapi import Request
from fastapi import Request, HTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from app.user_manager import user_manager
from app.logger import get_logger
logger = get_logger(__name__)
class AuthMiddleware(BaseHTTPMiddleware):
@@ -20,9 +23,18 @@ class AuthMiddleware(BaseHTTPMiddleware):
if user_id:
user = await user_manager.get_user(user_id)
if user:
request.state.user_id = user_id
request.state.user = user
request.state.is_admin = user.is_admin
# 检查用户是否被禁用 (trust_level = -1)
if user.trust_level == -1:
logger.warning(f"禁用用户尝试访问: {user_id} ({user.username})")
# 清除用户状态,视为未登录
request.state.user_id = None
request.state.user = None
request.state.is_admin = False
else:
# 用户正常,注入状态
request.state.user_id = user_id
request.state.user = user
request.state.is_admin = user.is_admin
else:
# 用户不存在,清除状态
request.state.user_id = None
+26 -17
View File
@@ -1,35 +1,44 @@
"""数据模型"""
"""数据模型导出"""
from app.models.project import Project
from app.models.outline import Outline
from app.models.character import Character
from app.models.chapter import Chapter
from app.models.character import Character
from app.models.relationship import CharacterRelationship, Organization, OrganizationMember, RelationshipType
from app.models.generation_history import GenerationHistory
from app.models.analysis_task import AnalysisTask
from app.models.batch_generation_task import BatchGenerationTask
from app.models.settings import Settings
from app.models.memory import StoryMemory, PlotAnalysis
from app.models.writing_style import WritingStyle
from app.models.project_default_style import ProjectDefaultStyle
from app.models.relationship import (
RelationshipType,
CharacterRelationship,
Organization,
OrganizationMember
)
from app.models.memory import StoryMemory, PlotAnalysis
from app.models.analysis_task import AnalysisTask
from app.models.mcp_plugin import MCPPlugin
from app.models.user import User, UserPassword
from app.models.regeneration_task import RegenerationTask
from app.models.career import Career, CharacterCareer
from app.models.prompt_template import PromptTemplate
__all__ = [
"Project",
"Outline",
"Character",
"Chapter",
"GenerationHistory",
"Settings",
"WritingStyle",
"ProjectDefaultStyle",
"RelationshipType",
"Character",
"CharacterRelationship",
"Organization",
"OrganizationMember",
"RelationshipType",
"GenerationHistory",
"AnalysisTask",
"BatchGenerationTask",
"Settings",
"StoryMemory",
"PlotAnalysis",
"AnalysisTask",
"WritingStyle",
"ProjectDefaultStyle",
"MCPPlugin",
"User",
"UserPassword",
"RegenerationTask",
"Career",
"CharacterCareer",
"PromptTemplate"
]
@@ -0,0 +1,43 @@
"""批量生成任务数据模型"""
from sqlalchemy import Column, String, Integer, DateTime, Boolean, JSON
from sqlalchemy.sql import func
from app.database import Base
import uuid
class BatchGenerationTask(Base):
"""批量生成任务表"""
__tablename__ = "batch_generation_tasks"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
project_id = Column(String(36), nullable=False, comment="项目ID")
user_id = Column(String(100), nullable=False, comment="用户ID")
# 任务配置
start_chapter_number = Column(Integer, nullable=False, comment="起始章节序号")
chapter_count = Column(Integer, nullable=False, comment="生成章节数量")
chapter_ids = Column(JSON, nullable=False, comment="待生成的章节ID列表")
style_id = Column(Integer, comment="使用的写作风格ID")
target_word_count = Column(Integer, default=3000, comment="目标字数")
enable_analysis = Column(Boolean, default=False, comment="是否启用同步分析")
# 任务状态
status = Column(String(20), default="pending", comment="任务状态: pending/running/completed/failed/cancelled")
total_chapters = Column(Integer, default=0, comment="总章节数")
completed_chapters = Column(Integer, default=0, comment="已完成章节数")
failed_chapters = Column(JSON, default=list, comment="失败的章节信息列表")
current_chapter_id = Column(String(36), comment="当前正在生成的章节ID")
current_chapter_number = Column(Integer, comment="当前正在生成的章节序号")
current_retry_count = Column(Integer, default=0, comment="当前章节重试次数")
max_retries = Column(Integer, default=3, comment="最大重试次数")
# 时间记录
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
started_at = Column(DateTime, comment="开始时间")
completed_at = Column(DateTime, comment="完成时间")
# 错误信息
error_message = Column(String(500), comment="错误信息")
def __repr__(self):
return f"<BatchGenerationTask(id={self.id}, status={self.status}, completed={self.completed_chapters}/{self.total_chapters})>"
+77
View File
@@ -0,0 +1,77 @@
"""职业数据模型"""
from sqlalchemy import Column, String, Text, DateTime, Integer, ForeignKey, Index
from sqlalchemy.sql import func
from app.database import Base
import uuid
class Career(Base):
"""职业表"""
__tablename__ = "careers"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)
# 基本信息
name = Column(String(100), nullable=False, comment="职业名称")
type = Column(String(20), nullable=False, comment="职业类型: main(主职业)/sub(副职业)")
description = Column(Text, comment="职业描述")
category = Column(String(50), comment="职业分类(如:战斗系、生产系、辅助系)")
# 阶段设定
stages = Column(Text, nullable=False, comment="职业阶段列表(JSON): [{level:1, name:'', description:''}, ...]")
max_stage = Column(Integer, nullable=False, default=10, comment="最大阶段数")
# 职业特性
requirements = Column(Text, comment="职业要求/限制")
special_abilities = Column(Text, comment="特殊能力描述")
worldview_rules = Column(Text, comment="世界观规则关联")
# 职业属性加成(可选,JSON格式)
attribute_bonuses = Column(Text, comment="属性加成(JSON): {strength: '+10%', intelligence: '+5%'}")
# 元数据
source = Column(String(20), default='ai', comment="来源: ai/manual")
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
__table_args__ = (
Index('idx_project_id', 'project_id'),
Index('idx_type', 'type'),
)
def __repr__(self):
return f"<Career(id={self.id}, name={self.name}, type={self.type})>"
class CharacterCareer(Base):
"""角色职业关联表"""
__tablename__ = "character_careers"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
character_id = Column(String(36), ForeignKey("characters.id", ondelete="CASCADE"), nullable=False)
career_id = Column(String(36), ForeignKey("careers.id", ondelete="CASCADE"), nullable=False)
career_type = Column(String(20), nullable=False, comment="main(主职业)/sub(副职业)")
# 阶段进度
current_stage = Column(Integer, nullable=False, default=1, comment="当前阶段(对应职业中的数值)")
stage_progress = Column(Integer, default=0, comment="阶段内进度(0-100")
# 时间记录
started_at = Column(String(100), comment="开始修炼时间(小说时间线)")
reached_current_stage_at = Column(String(100), comment="到达当前阶段时间")
# 备注
notes = Column(Text, comment="备注(如:修炼心得、特殊事件)")
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
__table_args__ = (
Index('idx_character_id', 'character_id'),
Index('idx_career_type', 'career_type'),
Index('idx_character_career', 'character_id', 'career_id', unique=True),
)
def __repr__(self):
return f"<CharacterCareer(character_id={self.character_id}, career_id={self.career_id}, type={self.career_type})>"
+9 -1
View File
@@ -17,8 +17,16 @@ class Chapter(Base):
summary = Column(Text, comment="章节摘要")
word_count = Column(Integer, default=0, comment="字数统计")
status = Column(String(20), default="draft", comment="章节状态")
# 大纲关联字段(实现一对多关系)
outline_id = Column(String(36), ForeignKey("outlines.id", ondelete="SET NULL"), nullable=True, comment="关联的大纲ID")
sub_index = Column(Integer, default=1, comment="大纲下的子章节序号")
# 大纲展开规划数据(JSON格式)
expansion_plan = Column(Text, comment="展开规划详情(JSON): 包含key_events, character_focus, emotional_tone等")
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
def __repr__(self):
return f"<Chapter(id={self.id}, chapter_number={self.chapter_number}, title={self.title})>"
return f"<Chapter(id={self.id}, chapter_number={self.chapter_number}, title={self.title}, outline_id={self.outline_id})>"
+8 -3
View File
@@ -1,5 +1,5 @@
"""角色数据模型"""
from sqlalchemy import Column, String, Text, DateTime, ForeignKey, Boolean
from sqlalchemy import Column, String, Text, DateTime, ForeignKey, Boolean, Integer
from sqlalchemy.sql import func
from app.database import Base
import uuid
@@ -14,8 +14,8 @@ class Character(Base):
# 基本信息
name = Column(String(100), nullable=False, comment="角色/组织名称")
age = Column(String(20), comment="年龄")
gender = Column(String(20), comment="性别")
age = Column(String(50), comment="年龄")
gender = Column(String(50), comment="性别")
is_organization = Column(Boolean, default=False, comment="是否为组织")
# 角色类型:protagonist(主角)/supporting(配角)/antagonist(反派)
@@ -32,6 +32,11 @@ class Character(Base):
organization_purpose = Column(String(500), comment="组织目的")
organization_members = Column(Text, comment="组织成员(JSON)")
# 职业相关字段(冗余字段,用于提升查询性能)
main_career_id = Column(String(36), ForeignKey("careers.id", ondelete="SET NULL"), comment="主职业ID")
main_career_stage = Column(Integer, comment="主职业当前阶段")
sub_careers = Column(Text, comment="副职业列表(JSON): [{\"career_id\": \"xxx\", \"stage\": 3}, ...]")
# 其他
avatar_url = Column(String(500), comment="头像URL")
traits = Column(Text, comment="特征标签(JSON)")
+52
View File
@@ -0,0 +1,52 @@
"""MCP插件配置数据模型"""
from sqlalchemy import Column, String, Text, Boolean, Integer, DateTime, Index, JSON
from sqlalchemy.sql import func
from app.database import Base
import uuid
class MCPPlugin(Base):
"""MCP插件配置表"""
__tablename__ = "mcp_plugins"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
user_id = Column(String(50), nullable=False, index=True, comment="用户ID")
# 插件基本信息
plugin_name = Column(String(100), nullable=False, comment="插件名称(唯一标识)")
display_name = Column(String(200), nullable=False, comment="显示名称")
description = Column(Text, comment="插件描述")
plugin_type = Column(String(50), default="http", comment="插件类型:http/stdio")
# 连接配置
server_url = Column(String(500), comment="服务器URLHTTP类型)")
command = Column(String(500), comment="启动命令(stdio类型)")
args = Column(JSON, comment="命令参数(stdio类型)")
env = Column(JSON, comment="环境变量")
headers = Column(JSON, comment="HTTP请求头")
# 插件配置
config = Column(JSON, comment="插件特定配置(JSON")
tools = Column(JSON, comment="提供的工具列表")
# 状态管理
enabled = Column(Boolean, default=True, comment="是否启用")
status = Column(String(50), default="inactive", comment="状态:active/inactive/error")
last_error = Column(Text, comment="最后错误信息")
last_test_at = Column(DateTime, comment="最后测试时间")
# 排序和分组
category = Column(String(100), default="general", comment="分类")
sort_order = Column(Integer, default=0, comment="排序顺序")
# 时间戳
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
__table_args__ = (
Index('idx_user_plugin', 'user_id', 'plugin_name', unique=True),
Index('idx_user_enabled', 'user_id', 'enabled'),
)
def __repr__(self):
return f"<MCPPlugin(id={self.id}, name={self.plugin_name}, enabled={self.enabled})>"
+2 -2
View File
@@ -9,7 +9,7 @@ class StoryMemory(Base):
"""故事记忆表 - 存储结构化的故事片段和元数据"""
__tablename__ = "story_memories"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
id = Column(String(100), primary_key=True, default=lambda: str(uuid.uuid4()))
project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False, index=True)
chapter_id = Column(String(36), ForeignKey("chapters.id", ondelete="CASCADE"), nullable=True, index=True)
@@ -45,7 +45,7 @@ class StoryMemory(Base):
# 伏笔相关字段
is_foreshadow = Column(Integer, default=0, comment="伏笔状态: 0=普通记忆, 1=已埋下伏笔, 2=伏笔已回收")
foreshadow_resolved_at = Column(String(36), ForeignKey("chapters.id", ondelete="SET NULL"), comment="伏笔回收的章节ID")
foreshadow_resolved_at = Column(String(100), ForeignKey("chapters.id", ondelete="SET NULL"), comment="伏笔回收的章节ID")
foreshadow_strength = Column(Float, comment="伏笔强度 0.0-1.0")
# 向量数据库关联
+10 -1
View File
@@ -1,5 +1,5 @@
"""项目数据模型"""
from sqlalchemy import Column, String, Text, DateTime, Integer
from sqlalchemy import Column, String, Text, DateTime, Integer, CheckConstraint
from sqlalchemy.sql import func
from app.database import Base
import uuid
@@ -10,6 +10,7 @@ class Project(Base):
__tablename__ = "projects"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
user_id = Column(String(100), nullable=False, index=True, comment="用户ID")
title = Column(String(200), nullable=False, comment="项目标题")
description = Column(Text, comment="项目简介")
theme = Column(Text, comment="主题")
@@ -19,6 +20,7 @@ class Project(Base):
status = Column(String(20), default="planning", comment="创作状态")
wizard_status = Column(String(20), default="incomplete", comment="向导完成状态: incomplete/completed")
wizard_step = Column(Integer, default=0, comment="向导当前步骤: 0-4")
outline_mode = Column(String(20), nullable=False, default="one-to-many", comment="大纲章节模式: one-to-one(传统模式) 或 one-to-many(细化模式)")
# 世界构建字段
world_time_period = Column(Text, comment="时间背景")
@@ -34,5 +36,12 @@ class Project(Base):
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
__table_args__ = (
CheckConstraint(
"outline_mode IN ('one-to-one', 'one-to-many')",
name='check_outline_mode'
),
)
def __repr__(self):
return f"<Project(id={self.id}, title={self.title})>"
+30
View File
@@ -0,0 +1,30 @@
"""提示词模板数据模型"""
from sqlalchemy import Column, String, Text, Boolean, DateTime, Index
from sqlalchemy.sql import func
from app.database import Base
import uuid
class PromptTemplate(Base):
"""提示词模板表"""
__tablename__ = "prompt_templates"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
user_id = Column(String(50), nullable=False, index=True, comment="用户ID")
template_key = Column(String(100), nullable=False, comment="模板键名")
template_name = Column(String(200), nullable=False, comment="模板显示名称")
template_content = Column(Text, nullable=False, comment="模板内容")
description = Column(Text, comment="模板描述")
category = Column(String(50), comment="模板分类")
parameters = Column(Text, comment="模板参数定义(JSON)")
is_active = Column(Boolean, default=True, comment="是否启用")
is_system_default = Column(Boolean, default=False, comment="是否为系统默认模板")
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
__table_args__ = (
Index('idx_user_template', 'user_id', 'template_key', unique=True),
)
def __repr__(self):
return f"<PromptTemplate(id={self.id}, user_id={self.user_id}, template_key={self.template_key})>"
+51
View File
@@ -0,0 +1,51 @@
"""章节重新生成任务模型"""
from sqlalchemy import Column, String, Text, Integer, DateTime, ForeignKey, JSON, Boolean
from sqlalchemy.sql import func
from app.database import Base
import uuid
class RegenerationTask(Base):
"""章节重新生成任务表"""
__tablename__ = "regeneration_tasks"
# 基本信息
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
chapter_id = Column(String(36), ForeignKey('chapters.id', ondelete='CASCADE'), nullable=False, index=True)
analysis_id = Column(String(36), nullable=True, comment="关联的分析结果ID")
user_id = Column(String(50), nullable=False, index=True)
project_id = Column(String(36), nullable=False, index=True)
# 修改指令
modification_instructions = Column(Text, nullable=False, comment="综合修改指令")
original_suggestions = Column(JSON, comment="来自分析的原始建议列表")
selected_suggestion_indices = Column(JSON, comment="用户选择的建议索引")
custom_instructions = Column(Text, comment="用户自定义修改意见")
# 生成参数
style_id = Column(Integer, nullable=True, comment="写作风格ID")
target_word_count = Column(Integer, default=3000, comment="目标字数")
focus_areas = Column(JSON, comment="重点优化方向")
preserve_elements = Column(JSON, comment="需要保留的元素配置")
# 状态跟踪
status = Column(String(20), default='pending', comment="pending/running/completed/failed")
progress = Column(Integer, default=0, comment="进度 0-100")
error_message = Column(Text, nullable=True)
# 内容版本
original_content = Column(Text, comment="原始章节内容快照")
original_word_count = Column(Integer, comment="原始字数")
regenerated_content = Column(Text, comment="重新生成的内容")
regenerated_word_count = Column(Integer, comment="新内容字数")
version_number = Column(Integer, default=1, comment="版本号")
version_note = Column(String(500), comment="版本说明")
# 时间戳
created_at = Column(DateTime, server_default=func.now())
started_at = Column(DateTime, nullable=True)
completed_at = Column(DateTime, nullable=True)
def __repr__(self):
return f"<RegenerationTask(id={self.id[:8]}..., chapter_id={self.chapter_id[:8]}..., status={self.status})>"
+2 -2
View File
@@ -38,7 +38,7 @@ class CharacterRelationship(Base):
relationship_name = Column(String(100), comment="自定义关系名称")
# 关系属性
intimacy_level = Column(Integer, default=50, comment="亲密度:0-100")
intimacy_level = Column(Integer, default=50, comment="亲密度:-100到100")
status = Column(String(20), default="active", comment="状态:active/broken/past/complicated")
description = Column(Text, comment="关系详细描述")
@@ -75,7 +75,7 @@ class Organization(Base):
# 组织特色
motto = Column(String(200), comment="宗旨/口号")
color = Column(String(20), comment="代表颜色")
color = Column(String(100), comment="代表颜色")
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
+1
View File
@@ -17,6 +17,7 @@ class Settings(Base):
llm_model = Column(String(100), default="gpt-4", comment="模型名称")
temperature = Column(Float, default=0.7, comment="温度参数")
max_tokens = Column(Integer, default=2000, comment="最大token数")
system_prompt = Column(Text, comment="系统级别提示词,每次AI调用都会使用")
preferences = Column(Text, comment="其他偏好设置(JSON)")
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
+47
View File
@@ -0,0 +1,47 @@
"""
用户数据模型 - 存储用户基本信息
"""
from sqlalchemy import Column, String, Integer, Boolean, DateTime
from sqlalchemy.sql import func
from app.database import Base
class User(Base):
"""用户模型 - 存储OAuth和本地用户信息"""
__tablename__ = "users"
user_id = Column(String(100), primary_key=True, index=True, comment="用户ID,格式:linuxdo_{id} 或 local_{id}")
username = Column(String(100), nullable=False, index=True, comment="用户名")
display_name = Column(String(200), nullable=False, comment="显示名称")
avatar_url = Column(String(500), nullable=True, comment="头像URL")
trust_level = Column(Integer, default=0, comment="信任等级(仅用于显示)")
is_admin = Column(Boolean, default=False, comment="是否为管理员")
linuxdo_id = Column(String(100), nullable=False, unique=True, index=True, comment="LinuxDO用户ID或本地用户ID")
created_at = Column(DateTime(timezone=True), server_default=func.now(), comment="创建时间")
last_login = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), comment="最后登录时间")
def to_dict(self):
"""转换为字典"""
return {
"user_id": self.user_id,
"username": self.username,
"display_name": self.display_name,
"avatar_url": self.avatar_url,
"trust_level": self.trust_level,
"is_admin": self.is_admin,
"linuxdo_id": self.linuxdo_id,
"created_at": self.created_at.isoformat() if self.created_at else None,
"last_login": self.last_login.isoformat() if self.last_login else None,
}
class UserPassword(Base):
"""用户密码模型 - 存储用户密码信息"""
__tablename__ = "user_passwords"
user_id = Column(String(100), primary_key=True, index=True, comment="用户ID")
username = Column(String(100), nullable=False, comment="用户名")
password_hash = Column(String(64), nullable=False, comment="密码哈希(SHA256")
has_custom_password = Column(Boolean, default=False, comment="是否为自定义密码")
created_at = Column(DateTime(timezone=True), server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), comment="更新时间")
+2 -2
View File
@@ -9,7 +9,7 @@ class WritingStyle(Base):
__tablename__ = "writing_styles"
id = Column(Integer, primary_key=True, autoincrement=True)
project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=True, comment="所属项目IDNULL表示全局预设风格)")
user_id = Column(String(255), ForeignKey("users.user_id", ondelete="CASCADE"), nullable=True, comment="所属用户IDNULL表示全局预设风格)")
name = Column(String(100), nullable=False, comment="风格名称")
style_type = Column(String(50), nullable=False, comment="风格类型:preset/custom")
preset_id = Column(String(50), comment="预设风格IDnatural/classical/modern等")
@@ -20,4 +20,4 @@ class WritingStyle(Base):
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
def __repr__(self):
return f"<WritingStyle(id={self.id}, name={self.name}, project_id={self.project_id})>"
return f"<WritingStyle(id={self.id}, name={self.name}, user_id={self.user_id})>"
+154
View File
@@ -0,0 +1,154 @@
"""职业相关的Pydantic模型"""
from pydantic import BaseModel, Field, ConfigDict
from typing import Optional, List, Dict, Any
from datetime import datetime
class CareerStage(BaseModel):
"""职业阶段模型"""
level: int = Field(..., description="阶段等级")
name: str = Field(..., description="阶段名称")
description: Optional[str] = Field(None, description="阶段描述")
class CareerBase(BaseModel):
"""职业基础模型"""
name: str = Field(..., description="职业名称")
type: str = Field(..., description="职业类型: main(主职业)/sub(副职业)")
description: Optional[str] = Field(None, description="职业描述")
category: Optional[str] = Field(None, description="职业分类")
stages: List[CareerStage] = Field(..., description="职业阶段列表")
max_stage: int = Field(10, description="最大阶段数")
requirements: Optional[str] = Field(None, description="职业要求/限制")
special_abilities: Optional[str] = Field(None, description="特殊能力描述")
worldview_rules: Optional[str] = Field(None, description="世界观规则关联")
attribute_bonuses: Optional[Dict[str, str]] = Field(None, description="属性加成")
class CareerCreate(CareerBase):
"""创建职业的请求模型"""
project_id: str = Field(..., description="项目ID")
source: str = Field("manual", description="来源: ai/manual")
class CareerUpdate(BaseModel):
"""更新职业的请求模型"""
name: Optional[str] = None
type: Optional[str] = None
description: Optional[str] = None
category: Optional[str] = None
stages: Optional[List[CareerStage]] = None
max_stage: Optional[int] = None
requirements: Optional[str] = None
special_abilities: Optional[str] = None
worldview_rules: Optional[str] = None
attribute_bonuses: Optional[Dict[str, str]] = None
class CareerResponse(BaseModel):
"""职业响应模型"""
id: str
project_id: str
name: str
type: str
description: Optional[str] = None
category: Optional[str] = None
stages: List[CareerStage]
max_stage: int
requirements: Optional[str] = None
special_abilities: Optional[str] = None
worldview_rules: Optional[str] = None
attribute_bonuses: Optional[Dict[str, str]] = None
source: str
created_at: datetime
updated_at: datetime
model_config = ConfigDict(from_attributes=True)
class CareerListResponse(BaseModel):
"""职业列表响应模型"""
total: int
main_careers: List[CareerResponse] = Field(default_factory=list, description="主职业列表")
sub_careers: List[CareerResponse] = Field(default_factory=list, description="副职业列表")
class CareerGenerateRequest(BaseModel):
"""AI生成职业体系的请求模型"""
project_id: str = Field(..., description="项目ID")
main_career_count: int = Field(5, description="主职业数量", ge=1, le=20)
sub_career_count: int = Field(8, description="副职业数量", ge=0, le=30)
enable_mcp: bool = Field(False, description="是否启用MCP工具增强")
# ===== 角色职业关联相关 =====
class CharacterCareerBase(BaseModel):
"""角色职业关联基础模型"""
career_id: str = Field(..., description="职业ID")
career_type: str = Field(..., description="main(主职业)/sub(副职业)")
current_stage: int = Field(1, description="当前阶段", ge=1)
stage_progress: int = Field(0, description="阶段内进度(0-100", ge=0, le=100)
started_at: Optional[str] = Field(None, description="开始修炼时间")
reached_current_stage_at: Optional[str] = Field(None, description="到达当前阶段时间")
notes: Optional[str] = Field(None, description="备注")
class CharacterCareerCreate(CharacterCareerBase):
"""创建角色职业关联的请求模型"""
character_id: str = Field(..., description="角色ID")
class CharacterCareerUpdate(BaseModel):
"""更新角色职业关联的请求模型"""
current_stage: Optional[int] = Field(None, ge=1)
stage_progress: Optional[int] = Field(None, ge=0, le=100)
reached_current_stage_at: Optional[str] = None
notes: Optional[str] = None
class CharacterCareerDetail(BaseModel):
"""角色职业详情模型(包含职业信息)"""
id: str
character_id: str
career_id: str
career_name: str = Field(..., description="职业名称")
career_type: str
current_stage: int
stage_name: str = Field(..., description="当前阶段名称")
stage_description: Optional[str] = Field(None, description="当前阶段描述")
stage_progress: int
max_stage: int = Field(..., description="该职业的最大阶段")
started_at: Optional[str] = None
reached_current_stage_at: Optional[str] = None
notes: Optional[str] = None
created_at: datetime
updated_at: datetime
class CharacterCareerResponse(BaseModel):
"""角色职业响应模型"""
main_career: Optional[CharacterCareerDetail] = Field(None, description="主职业")
sub_careers: List[CharacterCareerDetail] = Field(default_factory=list, description="副职业列表")
class SetMainCareerRequest(BaseModel):
"""设置主职业请求模型"""
career_id: str = Field(..., description="职业ID")
current_stage: int = Field(1, description="当前阶段", ge=1)
started_at: Optional[str] = Field(None, description="开始修炼时间")
class AddSubCareerRequest(BaseModel):
"""添加副职业请求模型"""
career_id: str = Field(..., description="职业ID")
current_stage: int = Field(1, description="当前阶段", ge=1)
started_at: Optional[str] = Field(None, description="开始修炼时间")
class UpdateCareerStageRequest(BaseModel):
"""更新职业阶段请求模型"""
current_stage: int = Field(..., description="新的阶段", ge=1)
stage_progress: int = Field(0, description="阶段进度", ge=0, le=100)
reached_current_stage_at: Optional[str] = Field(None, description="到达时间")
notes: Optional[str] = Field(None, description="备注")
+104 -5
View File
@@ -1,6 +1,6 @@
"""章节相关的Pydantic模型"""
from pydantic import BaseModel, Field
from typing import Optional
from pydantic import BaseModel, Field, ConfigDict
from typing import Optional, List, Dict, Any
from datetime import datetime
@@ -12,6 +12,9 @@ class ChapterBase(BaseModel):
summary: Optional[str] = Field(None, description="章节摘要")
word_count: Optional[int] = Field(0, description="字数")
status: Optional[str] = Field("draft", description="章节状态")
outline_id: Optional[str] = Field(None, description="关联的大纲ID")
sub_index: Optional[int] = Field(1, description="大纲下的子章节序号")
expansion_plan: Optional[str] = Field(None, description="展开规划详情(JSON)")
class ChapterCreate(BaseModel):
@@ -22,6 +25,9 @@ class ChapterCreate(BaseModel):
content: Optional[str] = Field(None, description="章节内容")
summary: Optional[str] = Field(None, description="章节摘要")
status: Optional[str] = Field("draft", description="章节状态")
outline_id: Optional[str] = Field(None, description="关联的大纲ID")
sub_index: Optional[int] = Field(1, description="大纲下的子章节序号")
expansion_plan: Optional[str] = Field(None, description="展开规划详情(JSON)")
class ChapterUpdate(BaseModel):
@@ -44,11 +50,15 @@ class ChapterResponse(BaseModel):
summary: Optional[str] = None
word_count: int = 0
status: str
outline_id: Optional[str] = None
sub_index: Optional[int] = 1
expansion_plan: Optional[str] = None
outline_title: Optional[str] = None # 大纲标题(从Outline表联查)
outline_order: Optional[int] = None # 大纲排序序号(从Outline表联查)
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class ChapterListResponse(BaseModel):
@@ -65,4 +75,93 @@ class ChapterGenerateRequest(BaseModel):
description="目标字数,默认3000字",
ge=500, # 最小500字
le=10000 # 最大10000字
)
)
enable_mcp: bool = Field(True, description="是否启用MCP工具增强(搜索参考资料)")
model: Optional[str] = Field(None, description="指定使用的AI模型,不提供则使用用户默认模型")
narrative_perspective: Optional[str] = Field(None, description="临时人称视角:first_person/third_person/omniscient,不提供则使用项目默认")
class BatchGenerateRequest(BaseModel):
"""批量生成章节的请求模型"""
start_chapter_number: int = Field(..., description="起始章节序号")
count: int = Field(..., description="生成章节数量", ge=1, le=20)
style_id: Optional[int] = Field(None, description="写作风格ID")
target_word_count: Optional[int] = Field(
3000,
description="目标字数,默认3000字",
ge=500,
le=10000
)
enable_analysis: bool = Field(False, description="是否启用同步分析")
enable_mcp: bool = Field(True, description="是否启用MCP工具增强(搜索参考资料)")
max_retries: int = Field(3, description="每个章节的最大重试次数", ge=0, le=5)
model: Optional[str] = Field(None, description="指定使用的AI模型,不提供则使用用户默认模型")
class BatchGenerateResponse(BaseModel):
"""批量生成响应模型"""
batch_id: str = Field(..., description="批次ID")
message: str = Field(..., description="响应消息")
chapters_to_generate: list[dict] = Field(..., description="待生成章节列表")
estimated_time_minutes: int = Field(..., description="预估耗时(分钟)")
class BatchGenerateStatusResponse(BaseModel):
"""批量生成状态响应模型"""
batch_id: str
status: str
total: int
completed: int
current_chapter_id: Optional[str] = None
current_chapter_number: Optional[int] = None
current_retry_count: Optional[int] = None
max_retries: Optional[int] = None
failed_chapters: list[dict] = []
created_at: Optional[str] = None
started_at: Optional[str] = None
completed_at: Optional[str] = None
error_message: Optional[str] = None
class SceneData(BaseModel):
"""场景数据模型"""
location: str = Field(..., description="场景地点")
characters: List[str] = Field(..., description="参与角色列表")
purpose: str = Field(..., description="场景目的")
class ExpansionPlanUpdate(BaseModel):
"""章节规划更新模型"""
summary: Optional[str] = Field(None, description="章节情节概要")
key_events: Optional[List[str]] = Field(None, description="关键事件列表")
character_focus: Optional[List[str]] = Field(None, description="涉及角色列表")
emotional_tone: Optional[str] = Field(None, description="情感基调")
narrative_goal: Optional[str] = Field(None, description="叙事目标")
conflict_type: Optional[str] = Field(None, description="冲突类型")
estimated_words: Optional[int] = Field(None, description="预估字数", ge=500, le=10000)
scenes: Optional[List[SceneData]] = Field(None, description="场景列表")
model_config = ConfigDict(json_schema_extra={
"example": {
"key_events": ["主角遇到挑战", "关键决策时刻"],
"character_focus": ["张三", "李四"],
"emotional_tone": "紧张激烈",
"narrative_goal": "推进主线剧情",
"conflict_type": "内心冲突",
"estimated_words": 3000,
"scenes": [
{
"location": "城市广场",
"characters": ["张三", "李四"],
"purpose": "初次相遇"
}
]
}
})
class ExpansionPlanResponse(BaseModel):
"""章节规划响应模型"""
id: str = Field(..., description="章节ID")
expansion_plan: Optional[Dict[str, Any]] = Field(None, description="规划数据")
message: str = Field(..., description="响应消息")
+55 -5
View File
@@ -1,5 +1,5 @@
"""角色相关的Pydantic模型"""
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ConfigDict
from typing import Optional, List, Dict, Any
from datetime import datetime
@@ -21,6 +21,36 @@ class CharacterBase(BaseModel):
traits: Optional[str] = Field(None, description="特征标签(JSON)")
class CharacterCreate(BaseModel):
"""手动创建角色的请求模型"""
project_id: str = Field(..., description="项目ID")
name: str = Field(..., description="角色/组织姓名")
age: Optional[str] = Field(None, description="年龄")
gender: Optional[str] = Field(None, description="性别")
is_organization: bool = Field(False, description="是否为组织")
role_type: Optional[str] = Field("supporting", description="角色类型:protagonist/supporting/antagonist")
personality: Optional[str] = Field(None, description="性格特点/组织特性")
background: Optional[str] = Field(None, description="背景故事")
appearance: Optional[str] = Field(None, description="外貌特征")
relationships: Optional[str] = Field(None, description="人际关系(JSON)")
organization_type: Optional[str] = Field(None, description="组织类型")
organization_purpose: Optional[str] = Field(None, description="组织目的")
organization_members: Optional[str] = Field(None, description="组织成员(JSON)")
traits: Optional[str] = Field(None, description="特征标签(JSON)")
avatar_url: Optional[str] = Field(None, description="头像URL")
# 组织额外字段
power_level: Optional[int] = Field(None, description="组织势力等级(0-100)")
location: Optional[str] = Field(None, description="组织所在地")
motto: Optional[str] = Field(None, description="组织格言/口号")
color: Optional[str] = Field(None, description="组织代表颜色")
# 职业字段
main_career_id: Optional[str] = Field(None, description="主职业ID")
main_career_stage: Optional[int] = Field(None, description="主职业阶段")
sub_careers: Optional[str] = Field(None, description="副职业列表JSON字符串")
class CharacterUpdate(BaseModel):
"""更新角色的请求模型"""
name: Optional[str] = None
@@ -36,6 +66,17 @@ class CharacterUpdate(BaseModel):
organization_purpose: Optional[str] = None
organization_members: Optional[str] = None
traits: Optional[str] = None
# 组织额外字段(会同步到Organization表)
power_level: Optional[int] = Field(None, description="组织势力等级(0-100)")
location: Optional[str] = Field(None, description="组织所在地")
motto: Optional[str] = Field(None, description="组织格言/口号")
color: Optional[str] = Field(None, description="组织代表颜色")
# 职业字段(会同步到CharacterCareer表)
main_career_id: Optional[str] = Field(None, description="主职业ID")
main_career_stage: Optional[int] = Field(None, description="主职业阶段")
sub_careers: Optional[str] = Field(None, description="副职业列表JSON字符串")
class CharacterResponse(CharacterBase):
@@ -46,8 +87,18 @@ class CharacterResponse(CharacterBase):
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
# 组织额外字段(从Organization表关联)
power_level: Optional[int] = Field(None, description="组织势力等级(0-100)")
location: Optional[str] = Field(None, description="组织所在地")
motto: Optional[str] = Field(None, description="组织格言/口号")
color: Optional[str] = Field(None, description="组织代表颜色")
# 职业信息字段
main_career_id: Optional[str] = Field(None, description="主职业ID")
main_career_stage: Optional[int] = Field(None, description="主职业阶段")
sub_careers: Optional[List[Dict[str, Any]]] = Field(None, description="副职业列表")
model_config = ConfigDict(from_attributes=True)
class CharacterGenerateRequest(BaseModel):
@@ -57,8 +108,7 @@ class CharacterGenerateRequest(BaseModel):
role_type: Optional[str] = Field(None, description="角色类型")
background: Optional[str] = Field(None, description="角色背景")
requirements: Optional[str] = Field(None, description="特殊要求")
provider: Optional[str] = Field(None, description="AI提供商")
model: Optional[str] = Field(None, description="AI模型")
enable_mcp: bool = Field(True, description="是否启用MCP工具增强(搜索人物原型参考)")
class CharacterListResponse(BaseModel):
+40
View File
@@ -19,6 +19,11 @@ class ChapterExportData(BaseModel):
word_count: int = 0
status: str = "draft"
created_at: Optional[str] = None
# 大纲细化功能新增字段
outline_title: Optional[str] = None # 关联的大纲标题(用于导入时重建关联)
sub_index: Optional[int] = None # 大纲下的子章节序号
expansion_plan: Optional[Dict[str, Any]] = None # 展开规划详情(JSON对象)
class CharacterExportData(BaseModel):
@@ -31,9 +36,20 @@ class CharacterExportData(BaseModel):
personality: Optional[str] = None
background: Optional[str] = None
appearance: Optional[str] = None
relationships: Optional[str] = None
traits: Optional[List[str]] = None
organization_type: Optional[str] = None
organization_purpose: Optional[str] = None
organization_members: Optional[str] = None
avatar_url: Optional[str] = None
main_career_id: Optional[str] = None
main_career_stage: Optional[int] = None
sub_careers: Optional[str] = None
# 组织专属字段
power_level: Optional[int] = None
location: Optional[str] = None
motto: Optional[str] = None
color: Optional[str] = None
created_at: Optional[str] = None
@@ -133,4 +149,28 @@ class ImportResult(BaseModel):
project_id: Optional[str] = None
message: str
statistics: Dict[str, int] = {}
details: Optional[Dict[str, List[str]]] = None
warnings: List[str] = []
class CharactersExportRequest(BaseModel):
"""角色/组织批量导出请求"""
character_ids: List[str] = Field(..., description="要导出的角色/组织ID列表")
class CharactersExportData(BaseModel):
"""角色/组织批量导出数据"""
version: str = "1.0.0"
export_time: str
export_type: str = "characters"
count: int
data: List[CharacterExportData]
class CharactersImportResult(BaseModel):
"""角色/组织导入结果"""
success: bool
message: str
statistics: Dict[str, int]
details: Dict[str, List[str]]
warnings: List[str] = []
+104
View File
@@ -0,0 +1,104 @@
"""MCP插件Pydantic模式"""
from pydantic import BaseModel, Field, ConfigDict
from typing import Optional, Dict, Any, List
from datetime import datetime
class MCPToolSchema(BaseModel):
"""MCP工具定义"""
name: str
description: Optional[str] = None
inputSchema: Optional[Dict[str, Any]] = None
category: Optional[str] = None
class MCPPluginBase(BaseModel):
"""插件基础模式"""
plugin_name: str = Field(..., description="插件唯一标识")
display_name: Optional[str] = Field(None, description="显示名称")
description: Optional[str] = Field(None, description="插件描述")
plugin_type: str = Field(default="http", description="插件类型:http/stdio")
category: str = Field(default="general", description="分类")
sort_order: int = Field(default=0, description="排序顺序")
class MCPPluginCreate(MCPPluginBase):
"""创建插件"""
server_url: Optional[str] = Field(None, description="服务器URLHTTP类型)")
command: Optional[str] = Field(None, description="启动命令(stdio类型)")
args: Optional[List[str]] = Field(None, description="命令参数")
env: Optional[Dict[str, str]] = Field(None, description="环境变量")
headers: Optional[Dict[str, str]] = Field(None, description="HTTP请求头")
config: Optional[Dict[str, Any]] = Field(None, description="插件特定配置")
enabled: bool = Field(default=True, description="是否启用")
class MCPPluginSimpleCreate(BaseModel):
"""简化的插件创建(通过标准MCP配置JSON)"""
config_json: str = Field(..., description="标准MCP配置JSON字符串")
enabled: bool = Field(default=True, description="是否启用")
category: str = Field(default="general", description="插件分类")
class MCPPluginUpdate(BaseModel):
"""更新插件"""
display_name: Optional[str] = None
description: Optional[str] = None
server_url: Optional[str] = None
command: Optional[str] = None
args: Optional[List[str]] = None
env: Optional[Dict[str, str]] = None
headers: Optional[Dict[str, str]] = None
config: Optional[Dict[str, Any]] = None
enabled: Optional[bool] = None
category: Optional[str] = None
sort_order: Optional[int] = None
class MCPPluginResponse(BaseModel):
"""插件响应 - 优化后只返回必要字段"""
id: str
plugin_name: str
display_name: str
description: Optional[str] = None
plugin_type: str
category: str
# HTTP类型字段
server_url: Optional[str] = None
headers: Optional[Dict[str, str]] = None
# Stdio类型字段
command: Optional[str] = None
args: Optional[List[str]] = None
env: Optional[Dict[str, str]] = None
# 状态字段
enabled: bool
status: str
last_error: Optional[str] = None
last_test_at: Optional[datetime] = None
# 时间戳
created_at: datetime
model_config = ConfigDict(from_attributes=True)
class MCPToolCall(BaseModel):
"""工具调用请求"""
plugin_id: str = Field(..., description="插件ID")
tool_name: str = Field(..., description="工具名称")
arguments: Dict[str, Any] = Field(default_factory=dict, description="工具参数")
class MCPTestResult(BaseModel):
"""测试结果"""
success: bool
message: str
response_time_ms: Optional[float] = None
tools_count: Optional[int] = None
tools: Optional[List[MCPToolSchema]] = None
error: Optional[str] = None
error_type: Optional[str] = None
suggestions: Optional[List[str]] = None
+144 -11
View File
@@ -1,9 +1,74 @@
"""大纲相关的Pydantic模型"""
from pydantic import BaseModel, Field
from typing import Optional
from pydantic import BaseModel, Field, ConfigDict
from typing import Optional, List, Dict, Any
from datetime import datetime
# 角色预测相关Schema
class CharacterPredictionRequest(BaseModel):
"""角色预测请求"""
project_id: str
start_chapter: int
chapter_count: int = 3
plot_stage: str = "development"
story_direction: Optional[str] = "自然延续"
enable_mcp: bool = True
class PredictedCharacter(BaseModel):
"""预测的角色信息"""
name: Optional[str] = None
role_description: str
suggested_role_type: str
importance: str
appearance_chapter: int
key_abilities: List[str] = []
plot_function: str
relationship_suggestions: List[Dict[str, str]] = []
class CharacterPredictionResponse(BaseModel):
"""角色预测响应"""
needs_new_characters: bool
reason: str
character_count: int
predicted_characters: List[PredictedCharacter]
# 组织预测相关Schema
class OrganizationPredictionRequest(BaseModel):
"""组织预测请求"""
project_id: str
start_chapter: int
chapter_count: int = 3
plot_stage: str = "development"
story_direction: Optional[str] = "自然延续"
enable_mcp: bool = True
class PredictedOrganization(BaseModel):
"""预测的组织信息"""
name: Optional[str] = None
organization_description: str
organization_type: str
importance: str
appearance_chapter: int
power_level: int = 50
plot_function: str
location: Optional[str] = None
motto: Optional[str] = None
initial_members: List[Dict[str, Any]] = []
relationship_suggestions: List[Dict[str, str]] = []
class OrganizationPredictionResponse(BaseModel):
"""组织预测响应"""
needs_new_organizations: bool
reason: str
organization_count: int
predicted_organizations: List[PredictedOrganization]
class OutlineBase(BaseModel):
"""大纲基础模型"""
title: str = Field(..., description="章节标题")
@@ -38,8 +103,7 @@ class OutlineResponse(BaseModel):
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class OutlineGenerateRequest(BaseModel):
@@ -61,6 +125,17 @@ class OutlineGenerateRequest(BaseModel):
story_direction: Optional[str] = Field(None, description="故事发展方向提示(续写时使用)")
plot_stage: str = Field("development", description="情节阶段: development(发展), climax(高潮), ending(结局)")
keep_existing: bool = Field(False, description="是否保留现有大纲(续写时)")
enable_mcp: bool = Field(True, description="是否启用MCP工具增强(搜索情节设计参考)")
# 自动角色引入相关参数
enable_auto_characters: bool = Field(True, description="是否启用自动角色引入(根据剧情推进自动创建新角色)")
require_character_confirmation: bool = Field(True, description="是否需要用户确认新角色(False则AI预测的角色直接创建)")
confirmed_characters: Optional[List[Dict[str, Any]]] = Field(None, description="用户确认的角色列表(跳过预测直接创建)")
# 自动组织引入相关参数
enable_auto_organizations: bool = Field(True, description="是否启用自动组织引入(根据剧情推进自动创建新组织)")
require_organization_confirmation: bool = Field(True, description="是否需要用户确认新组织(False则AI预测的组织直接创建)")
confirmed_organizations: Optional[List[Dict[str, Any]]] = Field(None, description="用户确认的组织列表(跳过预测直接创建)")
class ChapterOutlineGenerateRequest(BaseModel):
@@ -77,12 +152,70 @@ class OutlineListResponse(BaseModel):
items: list[OutlineResponse]
class OutlineReorderItem(BaseModel):
"""单个大纲重排序"""
id: str = Field(..., description="大纲ID")
order_index: int = Field(..., description="新的序号", ge=1)
class ChapterPlanItem(BaseModel):
"""单个章节规划"""
sub_index: int = Field(..., description="子章节序号", ge=1)
title: str = Field(..., description="章节标题")
plot_summary: str = Field(..., description="剧情摘要(200-300字)")
key_events: list[str] = Field(..., description="关键事件列表")
character_focus: list[str] = Field(..., description="主要涉及的角色")
emotional_tone: str = Field(..., description="情感基调")
narrative_goal: str = Field(..., description="叙事目标")
conflict_type: str = Field(..., description="冲突类型")
estimated_words: int = Field(3000, description="预计字数", ge=1000)
scenes: Optional[list[str]] = Field(None, description="场景列表(可选)")
class OutlineReorderRequest(BaseModel):
"""大纲批量重排序请求"""
orders: list[OutlineReorderItem] = Field(..., description="排序列表")
class OutlineExpansionRequest(BaseModel):
"""大纲展开为多章节的请求模型(outline_id从路径参数获取)"""
target_chapter_count: int = Field(3, description="目标章节数", ge=1, le=10)
expansion_strategy: str = Field("balanced", description="展开策略: balanced(均衡), climax(高潮重点), detail(细节丰富)")
enable_scene_analysis: bool = Field(False, description="是否包含场景规划")
auto_create_chapters: bool = Field(True, description="是否自动创建章节记录")
provider: Optional[str] = Field(None, description="AI提供商")
model: Optional[str] = Field(None, description="AI模型")
class OutlineExpansionResponse(BaseModel):
"""大纲展开响应模型"""
outline_id: str = Field(..., description="大纲ID")
outline_title: str = Field(..., description="大纲标题")
target_chapter_count: int = Field(..., description="目标章节数")
actual_chapter_count: int = Field(..., description="实际生成的章节数")
expansion_strategy: str = Field(..., description="使用的展开策略")
chapter_plans: list[ChapterPlanItem] = Field(..., description="章节规划列表")
created_chapters: Optional[list] = Field(None, description="已创建的章节列表")
class BatchOutlineExpansionRequest(BaseModel):
"""批量大纲展开请求模型"""
project_id: str = Field(..., description="项目ID")
outline_ids: Optional[list[str]] = Field(None, description="要展开的大纲ID列表(为空则展开所有)")
chapters_per_outline: int = Field(3, description="每个大纲的目标章节数", ge=1, le=10)
expansion_strategy: str = Field("balanced", description="展开策略")
enable_scene_analysis: bool = Field(False, description="是否包含场景规划")
auto_create_chapters: bool = Field(True, description="是否自动创建章节记录")
provider: Optional[str] = Field(None, description="AI提供商")
model: Optional[str] = Field(None, description="AI模型")
class BatchOutlineExpansionResponse(BaseModel):
"""批量大纲展开响应模型"""
project_id: str = Field(..., description="项目ID")
total_outlines_expanded: int = Field(..., description="总共展开的大纲数")
total_chapters_created: int = Field(..., description="总共创建的章节数")
expansion_results: list[OutlineExpansionResponse] = Field(..., description="展开结果列表")
skipped_outlines: Optional[list[dict]] = Field(None, description="跳过的大纲列表(已展开)")
class CreateChaptersFromPlansRequest(BaseModel):
"""根据已有规划创建章节的请求模型"""
chapter_plans: list[ChapterPlanItem] = Field(..., description="章节规划列表(来自之前的AI生成结果)")
class CreateChaptersFromPlansResponse(BaseModel):
"""根据已有规划创建章节的响应模型"""
outline_id: str = Field(..., description="大纲ID")
outline_title: str = Field(..., description="大纲标题")
chapters_created: int = Field(..., description="创建的章节数")
created_chapters: list = Field(..., description="创建的章节列表")
+12 -4
View File
@@ -1,6 +1,6 @@
"""项目相关的Pydantic模型"""
from pydantic import BaseModel, Field
from typing import Optional
from pydantic import BaseModel, Field, ConfigDict
from typing import Optional, Literal
from datetime import datetime
@@ -11,6 +11,10 @@ class ProjectBase(BaseModel):
theme: Optional[str] = Field(None, description="主题")
genre: Optional[str] = Field(None, description="小说类型")
target_words: Optional[int] = Field(None, description="目标字数")
outline_mode: Literal["one-to-one", "one-to-many"] = Field(
default="one-to-many",
description="大纲章节模式: one-to-one(传统模式,1大纲→1章节) 或 one-to-many(细化模式,1大纲→N章节)"
)
class ProjectCreate(ProjectBase):
@@ -51,11 +55,11 @@ class ProjectResponse(ProjectBase):
chapter_count: Optional[int] = None
narrative_perspective: Optional[str] = None
character_count: Optional[int] = None
outline_mode: str # 显式声明以确保响应中包含
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class ProjectListResponse(BaseModel):
@@ -73,6 +77,10 @@ class ProjectWizardRequest(BaseModel):
narrative_perspective: str = Field(..., description="叙事视角")
character_count: int = Field(5, ge=5, description="角色数量(至少5个)")
target_words: Optional[int] = Field(None, description="目标字数")
outline_mode: Literal["one-to-one", "one-to-many"] = Field(
default="one-to-many",
description="大纲章节模式"
)
class WorldBuildingResponse(BaseModel):
+89
View File
@@ -0,0 +1,89 @@
"""提示词模板相关的Pydantic模型"""
from pydantic import BaseModel, Field, ConfigDict
from typing import Optional, List
from datetime import datetime
class PromptTemplateBase(BaseModel):
"""提示词模板基础模型"""
template_key: str = Field(..., description="模板键名")
template_name: str = Field(..., description="模板显示名称")
template_content: str = Field(..., description="模板内容")
description: Optional[str] = Field(None, description="模板描述")
category: Optional[str] = Field(None, description="模板分类")
parameters: Optional[str] = Field(None, description="模板参数定义(JSON)")
is_active: bool = Field(True, description="是否启用")
class PromptTemplateCreate(PromptTemplateBase):
"""创建提示词模板请求模型"""
pass
class PromptTemplateUpdate(BaseModel):
"""更新提示词模板请求模型"""
template_name: Optional[str] = Field(None, description="模板显示名称")
template_content: Optional[str] = Field(None, description="模板内容")
description: Optional[str] = Field(None, description="模板描述")
category: Optional[str] = Field(None, description="模板分类")
parameters: Optional[str] = Field(None, description="模板参数定义(JSON)")
is_active: Optional[bool] = Field(None, description="是否启用")
class PromptTemplateResponse(PromptTemplateBase):
"""提示词模板响应模型"""
model_config = ConfigDict(from_attributes=True)
id: str
user_id: str
is_system_default: bool
created_at: datetime
updated_at: datetime
class PromptTemplateListResponse(BaseModel):
"""提示词模板列表响应"""
templates: List[PromptTemplateResponse]
total: int
categories: List[str]
class PromptTemplateCategoryResponse(BaseModel):
"""提示词模板分类响应"""
category: str
count: int
templates: List[PromptTemplateResponse]
class PromptTemplateExportItem(BaseModel):
"""提示词模板导出项模型"""
template_key: str = Field(..., description="模板键名")
template_name: str = Field(..., description="模板显示名称")
template_content: str = Field(..., description="模板内容")
description: Optional[str] = Field(None, description="模板描述")
category: Optional[str] = Field(None, description="模板分类")
parameters: Optional[str] = Field(None, description="模板参数定义(JSON)")
is_active: bool = Field(True, description="是否启用")
is_customized: bool = Field(..., description="是否为用户自定义(false=系统默认,true=用户自定义)")
system_content_hash: Optional[str] = Field(None, description="系统默认内容的哈希值,用于比对")
class PromptTemplateExport(BaseModel):
"""提示词模板导出模型"""
templates: List[PromptTemplateExportItem]
export_time: datetime
version: str = "2.0"
statistics: Optional[dict] = Field(None, description="导出统计信息")
class PromptTemplateImportResult(BaseModel):
"""提示词模板导入结果"""
message: str
statistics: dict = Field(..., description="导入统计信息")
converted_templates: List[dict] = Field(default_factory=list, description="被转换为自定义的模板列表")
class PromptTemplatePreviewRequest(BaseModel):
"""提示词模板预览请求"""
template_content: str = Field(..., description="模板内容")
parameters: dict = Field(..., description="参数字典")
+65
View File
@@ -0,0 +1,65 @@
"""章节重新生成相关的Schema定义"""
from pydantic import BaseModel, Field
from typing import Optional, List, Dict, Any
from datetime import datetime
class PreserveElementsConfig(BaseModel):
"""保留元素配置"""
preserve_structure: bool = Field(False, description="是否保留整体结构")
preserve_dialogues: List[str] = Field(default_factory=list, description="需要保留的对话片段关键词")
preserve_plot_points: List[str] = Field(default_factory=list, description="需要保留的情节点关键词")
preserve_character_traits: bool = Field(True, description="保持角色性格一致")
class ChapterRegenerateRequest(BaseModel):
"""章节重新生成请求"""
# 修改来源
modification_source: str = Field("custom", description="修改来源: custom/analysis_suggestions/mixed")
# 基于分析建议
selected_suggestion_indices: Optional[List[int]] = Field(None, description="选中的建议索引列表")
# 自定义修改指令
custom_instructions: Optional[str] = Field(None, description="用户自定义的修改要求")
# 保留配置
preserve_elements: Optional[PreserveElementsConfig] = Field(None, description="保留元素配置")
# 生成参数
style_id: Optional[int] = Field(None, description="写作风格ID")
target_word_count: int = Field(3000, description="目标字数", ge=500, le=10000)
focus_areas: List[str] = Field(default_factory=list, description="重点优化方向")
# 版本管理
save_as_version: bool = Field(True, description="是否保存为新版本")
version_note: Optional[str] = Field(None, description="版本说明", max_length=500)
auto_apply: bool = Field(False, description="是否自动应用(替换当前内容)")
class RegenerationTaskResponse(BaseModel):
"""重新生成任务响应"""
task_id: str
chapter_id: str
status: str
message: str
estimated_time_seconds: int = 120
class RegenerationTaskStatus(BaseModel):
"""重新生成任务状态"""
task_id: str
chapter_id: str
status: str
progress: int
error_message: Optional[str] = None
created_at: Optional[datetime] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
# 结果信息
original_word_count: Optional[int] = None
regenerated_word_count: Optional[int] = None
version_number: Optional[int] = None
+7 -11
View File
@@ -1,5 +1,5 @@
"""关系管理相关的Pydantic模型"""
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ConfigDict
from typing import Optional, List
from datetime import datetime
@@ -17,8 +17,7 @@ class RelationshipTypeResponse(BaseModel):
description: Optional[str] = None
created_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
# ============ 角色关系相关 ============
@@ -27,7 +26,7 @@ class CharacterRelationshipBase(BaseModel):
"""角色关系基础模型"""
relationship_type_id: Optional[int] = Field(None, description="关系类型ID")
relationship_name: Optional[str] = Field(None, description="自定义关系名称")
intimacy_level: int = Field(50, ge=0, le=100, description="亲密度:0-100")
intimacy_level: int = Field(50, ge=-100, le=100, description="亲密度:-100到100")
status: str = Field("active", description="状态:active/broken/past/complicated")
description: Optional[str] = Field(None, description="关系描述")
started_at: Optional[str] = Field(None, description="关系开始时间(故事时间)")
@@ -45,7 +44,7 @@ class CharacterRelationshipUpdate(BaseModel):
"""更新角色关系的请求模型"""
relationship_type_id: Optional[int] = None
relationship_name: Optional[str] = None
intimacy_level: Optional[int] = Field(None, ge=0, le=100)
intimacy_level: Optional[int] = Field(None, ge=-100, le=100)
status: Optional[str] = None
description: Optional[str] = None
started_at: Optional[str] = None
@@ -62,8 +61,7 @@ class CharacterRelationshipResponse(CharacterRelationshipBase):
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class RelationshipGraphNode(BaseModel):
@@ -127,8 +125,7 @@ class OrganizationResponse(OrganizationBase):
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class OrganizationDetailResponse(BaseModel):
@@ -185,8 +182,7 @@ class OrganizationMemberResponse(OrganizationMemberBase):
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class OrganizationMemberDetailResponse(BaseModel):
+61 -2
View File
@@ -1,6 +1,6 @@
"""设置相关的Pydantic模型"""
from pydantic import BaseModel, Field, ConfigDict
from typing import Optional
from typing import Optional, List
from datetime import datetime
@@ -14,6 +14,7 @@ class SettingsBase(BaseModel):
llm_model: Optional[str] = Field(default="gpt-4", description="模型名称")
temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0, description="温度参数")
max_tokens: Optional[int] = Field(default=2000, ge=1, description="最大token数")
system_prompt: Optional[str] = Field(default=None, description="系统级别提示词,每次AI调用都会使用")
preferences: Optional[str] = Field(default=None, description="其他偏好设置(JSON)")
@@ -34,4 +35,62 @@ class SettingsResponse(SettingsBase):
id: str
user_id: str
created_at: datetime
updated_at: datetime
updated_at: datetime
# ========== API配置预设相关模型 ==========
class APIKeyPresetConfig(BaseModel):
"""预设配置内容"""
model_config = ConfigDict(protected_namespaces=())
api_provider: str = Field(..., description="API提供商")
api_key: str = Field(..., description="API密钥")
api_base_url: Optional[str] = Field(None, description="自定义API地址")
llm_model: str = Field(..., description="模型名称")
temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="温度参数")
max_tokens: int = Field(default=2000, ge=1, description="最大token数")
class APIKeyPreset(BaseModel):
"""API配置预设"""
model_config = ConfigDict(protected_namespaces=())
id: str = Field(..., description="预设ID")
name: str = Field(..., min_length=1, max_length=50, description="预设名称")
description: Optional[str] = Field(None, max_length=200, description="预设描述")
is_active: bool = Field(default=False, description="是否激活")
created_at: datetime = Field(..., description="创建时间")
config: APIKeyPresetConfig = Field(..., description="配置内容")
class PresetCreateRequest(BaseModel):
"""创建预设请求"""
model_config = ConfigDict(protected_namespaces=())
name: str = Field(..., min_length=1, max_length=50, description="预设名称")
description: Optional[str] = Field(None, max_length=200, description="预设描述")
config: APIKeyPresetConfig = Field(..., description="配置内容")
class PresetUpdateRequest(BaseModel):
"""更新预设请求"""
model_config = ConfigDict(protected_namespaces=())
name: Optional[str] = Field(None, min_length=1, max_length=50, description="预设名称")
description: Optional[str] = Field(None, max_length=200, description="预设描述")
config: Optional[APIKeyPresetConfig] = Field(None, description="配置内容")
class PresetResponse(APIKeyPreset):
"""预设响应"""
pass
class PresetListResponse(BaseModel):
"""预设列表响应"""
model_config = ConfigDict(protected_namespaces=())
presets: List[PresetResponse] = Field(..., description="预设列表")
total: int = Field(..., description="总数")
active_preset_id: Optional[str] = Field(None, description="当前激活的预设ID")
+10 -7
View File
@@ -1,5 +1,5 @@
"""写作风格 Schema"""
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ConfigDict
from typing import Optional
from datetime import datetime
@@ -13,9 +13,13 @@ class WritingStyleBase(BaseModel):
prompt_content: str = Field(..., description="风格提示词内容")
class WritingStyleCreate(WritingStyleBase):
"""创建写作风格(仅用于创建项目自定义风格)"""
project_id: str = Field(..., description="所属项目ID")
class WritingStyleCreate(BaseModel):
"""创建写作风格(仅用于创建用户自定义风格)"""
name: str = Field(..., description="风格名称")
style_type: Optional[str] = Field(None, description="风格类型:preset/custom")
preset_id: Optional[str] = Field(None, description="预设风格ID")
description: Optional[str] = Field(None, description="风格描述")
prompt_content: str = Field(..., description="风格提示词内容")
class WritingStyleUpdate(BaseModel):
@@ -33,7 +37,7 @@ class SetDefaultStyleRequest(BaseModel):
class WritingStyleResponse(BaseModel):
"""写作风格响应"""
id: int
project_id: Optional[str] = None # NULL 表示全局预设风格
user_id: Optional[str] = None # NULL 表示全局预设风格
name: str
style_type: str
preset_id: Optional[str] = None
@@ -44,8 +48,7 @@ class WritingStyleResponse(BaseModel):
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class WritingStyleListResponse(BaseModel):
@@ -0,0 +1,6 @@
"""AI 客户端模块"""
from .base_client import BaseAIClient
from .openai_client import OpenAIClient
from .anthropic_client import AnthropicClient
__all__ = ["BaseAIClient", "OpenAIClient", "AnthropicClient"]
@@ -0,0 +1,142 @@
"""Anthropic 客户端"""
from typing import Any, AsyncGenerator, Dict, Optional
from anthropic import AsyncAnthropic
from app.logger import get_logger
from app.services.ai_config import AIClientConfig, default_config
logger = get_logger(__name__)
class AnthropicClient:
"""Anthropic API 客户端"""
def __init__(self, api_key: str, base_url: Optional[str] = None, config: Optional[AIClientConfig] = None):
self.config = config or default_config
kwargs = {"api_key": api_key}
if base_url:
kwargs["base_url"] = base_url
self.client = AsyncAnthropic(**kwargs)
async def chat_completion(
self,
messages: list,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
tools: Optional[list] = None,
tool_choice: Optional[str] = None,
) -> Dict[str, Any]:
kwargs = {
"model": model,
"max_tokens": max_tokens,
"temperature": temperature,
"messages": messages,
}
if system_prompt:
kwargs["system"] = system_prompt
if tools:
kwargs["tools"] = tools
if tool_choice == "required":
kwargs["tool_choice"] = {"type": "any"}
elif tool_choice == "auto":
kwargs["tool_choice"] = {"type": "auto"}
response = await self.client.messages.create(**kwargs)
tool_calls = []
content = ""
for block in response.content:
if block.type == "tool_use":
tool_calls.append({
"id": block.id,
"type": "function",
"function": {"name": block.name, "arguments": block.input},
})
elif block.type == "text":
content += block.text
return {
"content": content,
"tool_calls": tool_calls if tool_calls else None,
"finish_reason": response.stop_reason,
}
async def chat_completion_stream(
self,
messages: list,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
tools: Optional[list] = None,
tool_choice: Optional[str] = None,
) -> AsyncGenerator[Dict[str, Any], None]:
"""
流式生成支持工具调用
Yields:
Dict with keys:
- content: str - 文本内容块
- tool_calls: list - 工具调用列表如果有
- done: bool - 是否结束
"""
kwargs = {
"model": model,
"max_tokens": max_tokens,
"temperature": temperature,
"messages": messages,
}
if system_prompt:
kwargs["system"] = system_prompt
if tools:
kwargs["tools"] = tools
if tool_choice == "required":
kwargs["tool_choice"] = {"type": "any"}
elif tool_choice == "auto":
kwargs["tool_choice"] = {"type": "auto"}
try:
async with self.client.messages.stream(**kwargs) as stream:
try:
tool_calls = []
async for chunk in stream:
# 处理不同类型的块
if chunk.type == "text_delta":
yield {"content": chunk.text}
elif chunk.type == "tool_use_delta":
# 工具调用增量
if not tool_calls or tool_calls[-1].get("id") != chunk.id:
tool_calls.append({
"id": chunk.id,
"type": "function",
"function": {
"name": chunk.name,
"arguments": ""
}
})
# 追加参数
if tool_calls[-1]["function"]["arguments"] is None:
tool_calls[-1]["function"]["arguments"] = ""
tool_calls[-1]["function"]["arguments"] += chunk.input_gets_new_text or ""
elif chunk.type == "message_delta":
if chunk.stop_reason:
# 流结束
if tool_calls:
yield {"tool_calls": tool_calls}
yield {"done": True, "finish_reason": chunk.stop_reason}
except GeneratorExit:
# 生成器被关闭,这是正常的清理过程
logger.debug("Anthropic 流式响应生成器被关闭(GeneratorExit)")
raise
except Exception as iter_error:
logger.error(f"Anthropic 流式响应迭代出错: {str(iter_error)}")
raise
except GeneratorExit:
# 重新抛出GeneratorExit,让调用方处理
raise
except Exception as e:
logger.error(f"Anthropic 流式请求出错: {str(e)}")
raise
@@ -0,0 +1,154 @@
"""AI 客户端基类"""
import asyncio
import hashlib
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Dict, Optional
import httpx
from app.logger import get_logger
from app.services.ai_config import AIClientConfig, default_config
logger = get_logger(__name__)
# 全局 HTTP 客户端池
_http_client_pool: Dict[str, httpx.AsyncClient] = {}
_global_semaphore: Optional[asyncio.Semaphore] = None
def _get_semaphore(max_concurrent: int) -> asyncio.Semaphore:
"""获取全局信号量"""
global _global_semaphore
if _global_semaphore is None:
_global_semaphore = asyncio.Semaphore(max_concurrent)
return _global_semaphore
class BaseAIClient(ABC):
"""AI HTTP 客户端基类"""
def __init__(
self,
api_key: str,
base_url: str,
config: Optional[AIClientConfig] = None,
):
self.api_key = api_key
self.base_url = base_url.rstrip("/")
self.config = config or default_config
self.http_client = self._get_or_create_client()
def _get_client_key(self) -> str:
"""生成客户端唯一键"""
key_hash = hashlib.md5(self.api_key.encode()).hexdigest()[:8]
return f"{self.__class__.__name__}_{self.base_url}_{key_hash}"
def _get_or_create_client(self) -> httpx.AsyncClient:
"""获取或创建 HTTP 客户端"""
client_key = self._get_client_key()
if client_key in _http_client_pool:
client = _http_client_pool[client_key]
if not client.is_closed:
return client
del _http_client_pool[client_key]
http_cfg = self.config.http
client = httpx.AsyncClient(
timeout=httpx.Timeout(
connect=http_cfg.connect_timeout,
read=http_cfg.read_timeout,
write=http_cfg.write_timeout,
pool=http_cfg.pool_timeout,
),
limits=httpx.Limits(
max_keepalive_connections=http_cfg.max_keepalive_connections,
max_connections=http_cfg.max_connections,
keepalive_expiry=http_cfg.keepalive_expiry,
),
)
_http_client_pool[client_key] = client
logger.info(f"✅ 创建 HTTP 客户端: {client_key}")
return client
@abstractmethod
def _build_headers(self) -> Dict[str, str]:
"""构建请求头"""
pass
async def _request_with_retry(
self,
method: str,
endpoint: str,
payload: Dict[str, Any],
stream: bool = False,
) -> Any:
"""带重试的 HTTP 请求"""
url = f"{self.base_url}{endpoint}"
headers = self._build_headers()
retry_cfg = self.config.retry
rate_cfg = self.config.rate_limit
semaphore = _get_semaphore(rate_cfg.max_concurrent_requests)
async with semaphore:
await asyncio.sleep(rate_cfg.request_delay)
for attempt in range(retry_cfg.max_retries):
try:
if attempt > 0:
delay = min(
retry_cfg.base_delay * (retry_cfg.exponential_base ** attempt),
retry_cfg.max_delay,
)
logger.warning(f"⚠️ 重试 {attempt + 1}/{retry_cfg.max_retries},等待 {delay}s")
await asyncio.sleep(delay)
if stream:
return self.http_client.stream(method, url, headers=headers, json=payload)
response = await self.http_client.request(method, url, headers=headers, json=payload)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
if e.response.status_code in retry_cfg.non_retryable_status_codes:
raise
if attempt == retry_cfg.max_retries - 1:
raise
except (httpx.ConnectError, httpx.TimeoutException):
if attempt == retry_cfg.max_retries - 1:
raise
@abstractmethod
async def chat_completion(
self,
messages: list,
model: str,
temperature: float,
max_tokens: int,
tools: Optional[list] = None,
tool_choice: Optional[str] = None,
) -> Dict[str, Any]:
"""聊天补全"""
pass
@abstractmethod
async def chat_completion_stream(
self,
messages: list,
model: str,
temperature: float,
max_tokens: int,
) -> AsyncGenerator[str, None]:
"""流式聊天补全"""
pass
async def cleanup_all_clients():
"""清理所有 HTTP 客户端"""
for key, client in list(_http_client_pool.items()):
if not client.is_closed:
await client.aclose()
_http_client_pool.clear()
logger.info("✅ HTTP 客户端池已清理")
@@ -0,0 +1,189 @@
"""Gemini 客户端"""
from typing import Any, AsyncGenerator, Dict, List, Optional
import httpx
from app.services.ai_config import AIClientConfig, default_config
from app.logger import get_logger
logger = get_logger(__name__)
class GeminiClient:
"""Google Gemini API 客户端"""
def __init__(self, api_key: str, base_url: Optional[str] = None, config: Optional[AIClientConfig] = None):
self.api_key = api_key
self.base_url = (base_url or "https://generativelanguage.googleapis.com/v1beta").rstrip("/")
self.config = config or default_config
http_cfg = self.config.http
self.client = httpx.AsyncClient(
timeout=httpx.Timeout(
connect=http_cfg.connect_timeout,
read=http_cfg.read_timeout,
write=http_cfg.write_timeout,
pool=http_cfg.pool_timeout
)
)
def _convert_tools_to_gemini(self, tools: list) -> list:
"""将 OpenAI 格式工具转换为 Gemini 格式"""
gemini_tools = []
for tool in tools:
if tool.get("type") == "function":
func = tool["function"]
params = func.get("parameters", {}).copy() if func.get("parameters") else {}
params.pop("$schema", None)
params.pop("additionalProperties", None)
if params and "type" not in params:
params["type"] = "object"
decl = {
"name": func["name"],
"description": func.get("description") or func["name"],
}
if params:
decl["parameters"] = params
gemini_tools.append(decl)
return [{"functionDeclarations": gemini_tools}] if gemini_tools else []
async def chat_completion(
self,
messages: list,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
tools: Optional[list] = None,
tool_choice: Optional[str] = None,
) -> Dict[str, Any]:
url = f"{self.base_url}/models/{model}:generateContent?key={self.api_key}"
contents = []
for msg in messages:
role = "user" if msg["role"] == "user" else "model"
contents.append({"role": role, "parts": [{"text": msg["content"]}]})
payload = {
"contents": contents,
"generationConfig": {"temperature": temperature, "maxOutputTokens": max_tokens}
}
if system_prompt:
payload["systemInstruction"] = {"parts": [{"text": system_prompt}]}
if tools:
payload["tools"] = self._convert_tools_to_gemini(tools)
response = await self.client.post(url, json=payload)
response.raise_for_status()
data = response.json()
candidates = data.get("candidates", [])
if not candidates or len(candidates) == 0:
# 返回空内容而不是报错,保持流程继续
return {
"content": "",
"tool_calls": None,
"finish_reason": "stop"
}
parts = candidates[0].get("content", {}).get("parts", [])
text = ""
tool_calls = []
for part in parts:
if "text" in part:
text += part["text"]
elif "functionCall" in part:
fc = part["functionCall"]
tool_calls.append({
"id": f"call_{fc['name']}",
"type": "function",
"function": {"name": fc["name"], "arguments": fc.get("args", {})}
})
return {
"content": text,
"tool_calls": tool_calls if tool_calls else None,
"finish_reason": "tool_calls" if tool_calls else "stop"
}
async def chat_completion_stream(
self,
messages: list,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
tools: Optional[list] = None,
tool_choice: Optional[str] = None,
) -> AsyncGenerator[Dict[str, Any], None]:
"""
流式生成支持工具调用
Yields:
Dict with keys:
- content: str - 文本内容块
- tool_calls: list - 工具调用列表如果有
- done: bool - 是否结束
"""
url = f"{self.base_url}/models/{model}:streamGenerateContent?key={self.api_key}&alt=sse"
contents = []
for msg in messages:
role = "user" if msg["role"] == "user" else "model"
contents.append({"role": role, "parts": [{"text": msg["content"]}]})
payload = {
"contents": contents,
"generationConfig": {"temperature": temperature, "maxOutputTokens": max_tokens}
}
if system_prompt:
payload["systemInstruction"] = {"parts": [{"text": system_prompt}]}
if tools:
payload["tools"] = self._convert_tools_to_gemini(tools)
try:
async with self.client.stream("POST", url, json=payload) as response:
response.raise_for_status()
try:
async for line in response.aiter_lines():
if line.startswith("data: "):
import json
try:
data = json.loads(line[6:])
candidates = data.get("candidates", [])
if candidates and len(candidates) > 0:
parts = candidates[0].get("content", {}).get("parts", [])
if parts and len(parts) > 0:
text = ""
function_calls = []
for part in parts:
if "text" in part:
text += part["text"]
elif "functionCall" in part:
fc = part["functionCall"]
function_calls.append({
"id": f"call_{fc['name']}",
"type": "function",
"function": {
"name": fc["name"],
"arguments": fc.get("args", {})
}
})
if text:
yield {"content": text}
if function_calls:
yield {"tool_calls": function_calls}
except json.JSONDecodeError:
continue
except GeneratorExit:
# 生成器被关闭,这是正常的清理过程
logger.debug("Gemini 流式响应生成器被关闭(GeneratorExit)")
raise
except Exception as iter_error:
logger.error(f"Gemini 流式响应迭代出错: {str(iter_error)}")
raise
except GeneratorExit:
# 重新抛出GeneratorExit,让调用方处理
raise
except Exception as e:
logger.error(f"Gemini 流式请求出错: {str(e)}")
raise
@@ -0,0 +1,159 @@
"""OpenAI 客户端"""
import json
from typing import Any, AsyncGenerator, Dict, Optional
from app.logger import get_logger
from .base_client import BaseAIClient
logger = get_logger(__name__)
class OpenAIClient(BaseAIClient):
"""OpenAI API 客户端"""
def _build_headers(self) -> Dict[str, str]:
return {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
def _build_payload(
self,
messages: list,
model: str,
temperature: float,
max_tokens: int,
tools: Optional[list] = None,
tool_choice: Optional[str] = None,
stream: bool = False,
) -> Dict[str, Any]:
payload = {
"model": model,
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
}
if stream:
payload["stream"] = True
if tools:
# 清理 $schema 字段
cleaned = []
for t in tools:
tc = t.copy()
if "function" in tc and "parameters" in tc["function"]:
tc["function"]["parameters"] = {
k: v for k, v in tc["function"]["parameters"].items() if k != "$schema"
}
cleaned.append(tc)
payload["tools"] = cleaned
if tool_choice:
payload["tool_choice"] = tool_choice
return payload
async def chat_completion(
self,
messages: list,
model: str,
temperature: float,
max_tokens: int,
tools: Optional[list] = None,
tool_choice: Optional[str] = None,
) -> Dict[str, Any]:
payload = self._build_payload(messages, model, temperature, max_tokens, tools, tool_choice)
logger.debug(f"📤 OpenAI 请求 payload: {json.dumps(payload, ensure_ascii=False, indent=2)}")
data = await self._request_with_retry("POST", "/chat/completions", payload)
# 调试日志:输出原始响应
logger.debug(f"📥 OpenAI 原始响应: {json.dumps(data, ensure_ascii=False, indent=2)}")
choices = data.get("choices", [])
if not choices or len(choices) == 0:
raise ValueError("API 返回空 choices 或 choices 为空列表")
choice = choices[0]
message = choice.get("message", {})
return {
"content": message.get("content", ""),
"tool_calls": message.get("tool_calls"),
"finish_reason": choice.get("finish_reason"),
}
async def chat_completion_stream(
self,
messages: list,
model: str,
temperature: float,
max_tokens: int,
tools: Optional[list] = None,
tool_choice: Optional[str] = None,
) -> AsyncGenerator[Dict[str, Any], None]:
"""
流式生成支持工具调用
Yields:
Dict with keys:
- content: str - 文本内容块
- tool_calls: list - 工具调用列表如果有
- done: bool - 是否结束
"""
payload = self._build_payload(messages, model, temperature, max_tokens, tools, tool_choice, stream=True)
tool_calls_buffer = {} # 收集工具调用块
try:
async with await self._request_with_retry("POST", "/chat/completions", payload, stream=True) as response:
response.raise_for_status()
try:
async for line in response.aiter_lines():
if line.startswith("data: "):
data_str = line[6:]
if data_str.strip() == "[DONE]":
# 流结束,检查是否有工具调用需要处理
if tool_calls_buffer:
yield {"tool_calls": list(tool_calls_buffer.values()), "done": True}
yield {"done": True}
break
try:
data = json.loads(data_str)
choices = data.get("choices", [])
if choices and len(choices) > 0:
delta = choices[0].get("delta", {})
content = delta.get("content", "")
# 检查工具调用
tc_list = delta.get("tool_calls")
if tc_list:
for tc in tc_list:
index = tc.get("index", 0)
if index not in tool_calls_buffer:
tool_calls_buffer[index] = tc
else:
existing = tool_calls_buffer[index]
# 合并 function.arguments
if "function" in tc and "function" in existing:
if tc["function"].get("arguments"):
existing["function"]["arguments"] = (
existing["function"].get("arguments", "") +
tc["function"]["arguments"]
)
if content:
yield {"content": content}
except json.JSONDecodeError:
continue
except GeneratorExit:
# 生成器被关闭,这是正常的清理过程
logger.debug("流式响应生成器被关闭(GeneratorExit)")
raise
except Exception as iter_error:
logger.error(f"流式响应迭代出错: {str(iter_error)}")
raise
except GeneratorExit:
# 重新抛出GeneratorExit,让调用方处理
raise
except Exception as e:
logger.error(f"流式请求出错: {str(e)}")
raise
+44
View File
@@ -0,0 +1,44 @@
"""AI 服务配置管理"""
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class HTTPClientConfig:
"""HTTP 客户端配置"""
connect_timeout: float = 90.0
read_timeout: float = 300.0
write_timeout: float = 90.0
pool_timeout: float = 90.0
max_keepalive_connections: int = 50
max_connections: int = 100
keepalive_expiry: float = 60.0
@dataclass
class RetryConfig:
"""重试配置"""
max_retries: int = 3
base_delay: float = 0.2
max_delay: float = 10.0
exponential_base: int = 2
non_retryable_status_codes: tuple = field(default_factory=lambda: (401, 403, 404))
@dataclass
class RateLimitConfig:
"""限流配置"""
max_concurrent_requests: int = 5
request_delay: float = 0.2
@dataclass
class AIClientConfig:
"""AI 客户端完整配置"""
http: HTTPClientConfig = field(default_factory=HTTPClientConfig)
retry: RetryConfig = field(default_factory=RetryConfig)
rate_limit: RateLimitConfig = field(default_factory=RateLimitConfig)
# 全局默认配置
default_config = AIClientConfig()
@@ -0,0 +1,6 @@
"""AI Provider 模块"""
from .base_provider import BaseAIProvider
from .openai_provider import OpenAIProvider
from .anthropic_provider import AnthropicProvider
__all__ = ["BaseAIProvider", "OpenAIProvider", "AnthropicProvider"]
@@ -0,0 +1,161 @@
"""Anthropic Provider"""
from typing import Any, AsyncGenerator, Dict, List, Optional
from app.logger import get_logger
from app.services.ai_clients.anthropic_client import AnthropicClient
from .base_provider import BaseAIProvider
logger = get_logger(__name__)
class AnthropicProvider(BaseAIProvider):
"""Anthropic 提供商"""
def __init__(self, client: AnthropicClient):
self.client = client
async def generate(
self,
prompt: str,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: Optional[str] = None,
) -> Dict[str, Any]:
messages = [{"role": "user", "content": prompt}]
return await self.client.chat_completion(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_prompt,
tools=tools,
tool_choice=tool_choice,
)
async def generate_stream(
self,
prompt: str,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: Optional[str] = None,
user_id: Optional[str] = None,
) -> AsyncGenerator[str, None]:
# 如果有工具,使用真正的流式工具调用
if tools:
logger.debug(f"🔧 AnthropicProvider: 有 {len(tools)} 个工具,使用流式处理")
messages = [{"role": "user", "content": prompt}]
actual_tool_choice = tool_choice if tool_choice else "auto"
tool_calls_buffer = []
async for chunk in self.client.chat_completion_stream(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_prompt,
tools=tools,
tool_choice=actual_tool_choice,
):
# 检查是否有工具调用
if chunk.get("tool_calls"):
tool_calls_buffer.extend(chunk["tool_calls"])
logger.debug(f"🔧 收到工具调用: {len(chunk['tool_calls'])}")
# 检查是否结束
if chunk.get("done"):
if tool_calls_buffer:
logger.info(f"🔧 流式结束,处理 {len(tool_calls_buffer)} 个工具调用")
from app.mcp import mcp_client
actual_user_id = user_id or ""
tool_results = await mcp_client.batch_call_tools(
user_id=actual_user_id,
tool_calls=tool_calls_buffer
)
# 将工具结果注入到上下文中
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
# 构建最终提示词,要求AI基于工具结果回答
final_prompt = f"{prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"
final_messages = [{"role": "user", "content": final_prompt}]
# 递归调用生成最终结果
async for final_chunk in self._generate_with_tools(
final_messages, model, temperature, max_tokens, system_prompt, tools, user_id
):
yield final_chunk
break
# 输出文本内容
if chunk.get("content"):
yield chunk["content"]
return
# 无工具时普通流式生成
messages = [{"role": "user", "content": prompt}]
async for chunk in self.client.chat_completion_stream(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_prompt,
):
# 确保只 yield 字符串内容,避免 yield 字典导致类型错误
if isinstance(chunk, dict):
if chunk.get("content"):
yield chunk["content"]
else:
yield chunk
async def _generate_with_tools(
self,
messages: list,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
tools: list = None,
user_id: Optional[str] = None,
) -> AsyncGenerator[str, None]:
"""辅助方法:带工具的流式生成"""
tool_calls_buffer = []
async for chunk in self.client.chat_completion_stream(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_prompt,
tools=tools,
tool_choice="auto",
):
if chunk.get("tool_calls"):
tool_calls_buffer.extend(chunk["tool_calls"])
logger.debug(f"🔧 _generate_with_tools 收到工具调用: {len(chunk['tool_calls'])}")
if chunk.get("done"):
if tool_calls_buffer:
from app.mcp import mcp_client
actual_user_id = user_id or ""
tool_results = await mcp_client.batch_call_tools(
user_id=actual_user_id,
tool_calls=tool_calls_buffer
)
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
messages.append({"role": "user", "content": f"{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"})
async for final_chunk in self._generate_with_tools(
messages, model, temperature, max_tokens, system_prompt, tools, user_id
):
yield final_chunk
break
if chunk.get("content"):
yield chunk["content"]
@@ -0,0 +1,36 @@
"""AI Provider 基类"""
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Dict, List, Optional
class BaseAIProvider(ABC):
"""AI 提供商抽象基类"""
@abstractmethod
async def generate(
self,
prompt: str,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: Optional[str] = None,
) -> Dict[str, Any]:
"""生成文本"""
pass
@abstractmethod
async def generate_stream(
self,
prompt: str,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: Optional[str] = None,
user_id: Optional[str] = None,
) -> AsyncGenerator[str, None]:
"""流式生成"""
pass
@@ -0,0 +1,159 @@
"""Gemini Provider"""
from typing import Any, AsyncGenerator, Dict, List, Optional
from app.logger import get_logger
from app.services.ai_clients.gemini_client import GeminiClient
from .base_provider import BaseAIProvider
logger = get_logger(__name__)
class GeminiProvider(BaseAIProvider):
def __init__(self, client: GeminiClient):
self.client = client
async def generate(
self,
prompt: str,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: Optional[str] = None,
) -> Dict[str, Any]:
messages = [{"role": "user", "content": prompt}]
return await self.client.chat_completion(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_prompt,
tools=tools,
tool_choice=tool_choice,
)
async def generate_stream(
self,
prompt: str,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: Optional[str] = None,
user_id: Optional[str] = None,
) -> AsyncGenerator[str, None]:
# 如果有工具,使用真正的流式工具调用
if tools:
logger.debug(f"🔧 GeminiProvider: 有 {len(tools)} 个工具,使用流式处理")
messages = [{"role": "user", "content": prompt}]
actual_tool_choice = tool_choice if tool_choice else "auto"
tool_calls_buffer = []
async for chunk in self.client.chat_completion_stream(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_prompt,
tools=tools,
tool_choice=actual_tool_choice,
):
# 检查是否有工具调用
if chunk.get("tool_calls"):
tool_calls_buffer.extend(chunk["tool_calls"])
logger.debug(f"🔧 收到工具调用: {len(chunk['tool_calls'])}")
# 检查是否结束
if chunk.get("done"):
if tool_calls_buffer:
logger.info(f"🔧 流式结束,处理 {len(tool_calls_buffer)} 个工具调用")
from app.mcp import mcp_client
actual_user_id = user_id or ""
tool_results = await mcp_client.batch_call_tools(
user_id=actual_user_id,
tool_calls=tool_calls_buffer
)
# 将工具结果注入到上下文中
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
# 构建最终提示词,要求AI基于工具结果回答
final_prompt = f"{prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"
final_messages = [{"role": "user", "content": final_prompt}]
# 递归调用生成最终结果
async for final_chunk in self._generate_with_tools(
final_messages, model, temperature, max_tokens, system_prompt, tools, user_id
):
yield final_chunk
break
# 输出文本内容
if chunk.get("content"):
yield chunk["content"]
return
# 无工具时普通流式生成
messages = [{"role": "user", "content": prompt}]
async for chunk in self.client.chat_completion_stream(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_prompt,
):
# 确保只 yield 字符串内容,避免 yield 字典导致类型错误
if isinstance(chunk, dict):
if chunk.get("content"):
yield chunk["content"]
else:
yield chunk
async def _generate_with_tools(
self,
messages: list,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
tools: list = None,
user_id: Optional[str] = None,
) -> AsyncGenerator[str, None]:
"""辅助方法:带工具的流式生成"""
tool_calls_buffer = []
async for chunk in self.client.chat_completion_stream(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_prompt,
tools=tools,
tool_choice="auto",
):
if chunk.get("tool_calls"):
tool_calls_buffer.extend(chunk["tool_calls"])
logger.debug(f"🔧 _generate_with_tools 收到工具调用: {len(chunk['tool_calls'])}")
if chunk.get("done"):
if tool_calls_buffer:
from app.mcp import mcp_client
actual_user_id = user_id or ""
tool_results = await mcp_client.batch_call_tools(
user_id=actual_user_id,
tool_calls=tool_calls_buffer
)
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
messages.append({"role": "user", "content": f"{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"})
async for final_chunk in self._generate_with_tools(
messages, model, temperature, max_tokens, system_prompt, tools, user_id
):
yield final_chunk
break
if chunk.get("content"):
yield chunk["content"]
@@ -0,0 +1,161 @@
"""OpenAI Provider"""
from typing import Any, AsyncGenerator, Dict, List, Optional
from app.logger import get_logger
from app.services.ai_clients.openai_client import OpenAIClient
from .base_provider import BaseAIProvider
logger = get_logger(__name__)
class OpenAIProvider(BaseAIProvider):
"""OpenAI 提供商"""
def __init__(self, client: OpenAIClient):
self.client = client
async def generate(
self,
prompt: str,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: Optional[str] = None,
) -> Dict[str, Any]:
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
return await self.client.chat_completion(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
tools=tools,
tool_choice=tool_choice,
)
async def generate_stream(
self,
prompt: str,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: Optional[str] = None,
user_id: Optional[str] = None,
) -> AsyncGenerator[str, None]:
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
# 如果有工具,使用真正的流式工具调用
if tools:
logger.debug(f"🔧 OpenAIProvider: 有 {len(tools)} 个工具,使用流式处理")
actual_tool_choice = tool_choice if tool_choice else "auto"
tool_calls_buffer = []
async for chunk in self.client.chat_completion_stream(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
tools=tools,
tool_choice=actual_tool_choice,
):
# 检查是否有工具调用
if chunk.get("tool_calls"):
tool_calls_buffer.extend(chunk["tool_calls"])
logger.debug(f"🔧 收到工具调用: {len(chunk['tool_calls'])}")
# 检查是否结束
if chunk.get("done"):
if tool_calls_buffer:
logger.info(f"🔧 流式结束,处理 {len(tool_calls_buffer)} 个工具调用")
from app.mcp import mcp_client
actual_user_id = user_id or ""
tool_results = await mcp_client.batch_call_tools(
user_id=actual_user_id,
tool_calls=tool_calls_buffer
)
# 将工具结果注入到上下文中
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
# 构建最终提示词,要求AI基于工具结果回答
final_prompt = f"{prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"
final_messages = messages.copy()
final_messages.append({"role": "user", "content": final_prompt})
# 递归调用生成最终结果
async for final_chunk in self._generate_with_tools(
final_messages, model, temperature, max_tokens, tools, user_id
):
yield final_chunk
break
# 输出文本内容
if chunk.get("content"):
yield chunk["content"]
return
# 无工具时普通流式生成
async for chunk in self.client.chat_completion_stream(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
):
# 确保只 yield 字符串内容,避免 yield 字典导致类型错误
if isinstance(chunk, dict):
if chunk.get("content"):
yield chunk["content"]
else:
yield chunk
async def _generate_with_tools(
self,
messages: list,
model: str,
temperature: float,
max_tokens: int,
tools: list,
user_id: Optional[str] = None,
) -> AsyncGenerator[str, None]:
"""辅助方法:带工具的流式生成(无tool_choice,AI自由决定)"""
async for chunk in self.client.chat_completion_stream(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
tools=tools,
tool_choice="auto",
):
if chunk.get("tool_calls"):
from app.mcp import mcp_client
actual_user_id = user_id or ""
tool_results = await mcp_client.batch_call_tools(
user_id=actual_user_id,
tool_calls=chunk["tool_calls"]
)
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
# 再次调用获取最终回答
messages.append({"role": "user", "content": f"{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"})
async for final_chunk in self._generate_with_tools(
messages, model, temperature, max_tokens, tools, user_id
):
yield final_chunk
break
if chunk.get("done"):
break
if chunk.get("content"):
yield chunk["content"]
+475 -394
View File
@@ -1,17 +1,68 @@
"""AI服务封装 - 统一的OpenAI和Claude接口"""
from typing import Optional, AsyncGenerator, List, Dict, Any
from openai import AsyncOpenAI
from anthropic import AsyncAnthropic
"""AI服务封装 - 统一的AI接口
重构后支持自动MCP工具加载
- 所有AI方法在请求前自动检查用户MCP配置
- 如果有启用的MCP插件且有可用工具自动发送tools
- 通过 auto_mcp 参数控制是否启用自动工具加载
"""
from typing import Optional, AsyncGenerator, List, Dict, Any, Union
from app.config import settings as app_settings
from app.logger import get_logger
import httpx
from app.services.ai_config import AIClientConfig, default_config
from app.services.ai_clients.openai_client import OpenAIClient
from app.services.ai_clients.anthropic_client import AnthropicClient
from app.services.ai_clients.gemini_client import GeminiClient
from app.services.ai_clients.base_client import cleanup_all_clients
from app.services.ai_providers.openai_provider import OpenAIProvider
from app.services.ai_providers.anthropic_provider import AnthropicProvider
from app.services.ai_providers.gemini_provider import GeminiProvider
from app.services.ai_providers.base_provider import BaseAIProvider
from app.services.json_helper import clean_json_response, parse_json
# 导出清理函数
cleanup_http_clients = cleanup_all_clients
logger = get_logger(__name__)
class AIService:
"""AI服务统一接口 - 支持从用户设置或全局配置初始化"""
"""
AI服务统一接口
MCP工具支持
- 在创建服务时传入 user_id db_session
- 根据用户MCP插件的enabled状态自动决定是否启用MCP
- 如果有任意一个MCP插件启用则加载并使用工具
- 如果所有插件都关闭则不使用任何MCP工具
- 通过 auto_mcp=False 可临时禁用自动工具加载
- 通过 mcp_max_rounds 控制工具调用轮数
- 通过 clear_mcp_cache() 可清理MCP工具缓存
MCP启用逻辑backend/app/api/settings.py 中的 get_user_ai_service
- 查询用户的所有MCP插件
- 如果有启用的插件 (enabled=True) enable_mcp=True
- 如果所有插件都关闭或没有插件 enable_mcp=False
使用示例
# 创建支持MCP的AI服务(根据插件状态自动决定是否启用)
ai_service = create_user_ai_service_with_mcp(
api_provider="openai",
api_key="...",
user_id="user123",
db_session=db
)
# 自动加载MCP工具(如果有启用的插件)
result = await ai_service.generate_text(prompt="...")
# 临时禁用MCP工具
result = await ai_service.generate_text(prompt="...", auto_mcp=False)
# 自定义轮数
result = await ai_service.generate_text(prompt="...", mcp_max_rounds=3)
"""
def __init__(
self,
api_provider: Optional[str] = None,
@@ -19,106 +70,252 @@ class AIService:
api_base_url: Optional[str] = None,
default_model: Optional[str] = None,
default_temperature: Optional[float] = None,
default_max_tokens: Optional[int] = None
default_max_tokens: Optional[int] = None,
default_system_prompt: Optional[str] = None,
config: Optional[AIClientConfig] = None,
# MCP支持参数
user_id: Optional[str] = None,
db_session: Optional[Any] = None,
enable_mcp: bool = True,
):
"""
初始化AI客户端优化并发性能
Args:
api_provider: API提供商 (openai/anthropic)为None时使用全局配置
api_key: API密钥为None时使用全局配置
api_base_url: API基础URL为None时使用全局配置
default_model: 默认模型为None时使用全局配置
default_temperature: 默认温度为None时使用全局配置
default_max_tokens: 默认最大tokens为None时使用全局配置
"""
# 保存用户设置或使用全局配置
self.api_provider = api_provider or app_settings.default_ai_provider
self.default_model = default_model or app_settings.default_model
self.default_temperature = default_temperature or app_settings.default_temperature
self.default_max_tokens = default_max_tokens or app_settings.default_max_tokens
self.default_system_prompt = default_system_prompt
self.config = config or default_config
# 初始化OpenAI客户端
# MCP配置
self.user_id = user_id
self.db_session = db_session
self._enable_mcp = enable_mcp
self._cached_tools: Optional[List[Dict]] = None
self._tools_loaded = False
self._openai_provider: Optional[OpenAIProvider] = None
self._anthropic_provider: Optional[AnthropicProvider] = None
self._gemini_provider: Optional[GeminiProvider] = None
# 初始化 OpenAI
openai_key = api_key if api_provider == "openai" else app_settings.openai_api_key
if openai_key:
try:
limits = httpx.Limits(
max_keepalive_connections=50,
max_connections=100,
keepalive_expiry=30.0
)
http_client = httpx.AsyncClient(
timeout=httpx.Timeout(connect=60.0, read=180.0, write=60.0, pool=60.0),
limits=limits,
headers={
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
}
)
client_kwargs = {
"api_key": openai_key,
"http_client": http_client
}
base_url = api_base_url if api_provider == "openai" else app_settings.openai_base_url
if base_url:
client_kwargs["base_url"] = base_url
self.openai_client = AsyncOpenAI(**client_kwargs)
self.openai_http_client = http_client
self.openai_api_key = openai_key
self.openai_base_url = base_url
logger.info("✅ OpenAI客户端初始化成功")
except Exception as e:
logger.error(f"OpenAI客户端初始化失败: {e}")
self.openai_client = None
self.openai_http_client = None
self.openai_api_key = None
self.openai_base_url = None
else:
self.openai_client = None
self.openai_http_client = None
self.openai_api_key = None
self.openai_base_url = None
logger.warning("OpenAI API key未配置")
base_url = api_base_url if api_provider == "openai" else app_settings.openai_base_url
client = OpenAIClient(openai_key, base_url or "https://api.openai.com/v1", self.config)
self._openai_provider = OpenAIProvider(client)
# 初始化Anthropic客户端
# 初始化 Anthropic
anthropic_key = api_key if api_provider == "anthropic" else app_settings.anthropic_api_key
if anthropic_key:
try:
limits = httpx.Limits(
max_keepalive_connections=50,
max_connections=100,
keepalive_expiry=30.0
)
http_client = httpx.AsyncClient(
timeout=httpx.Timeout(connect=60.0, read=180.0, write=60.0, pool=60.0),
limits=limits,
headers={
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
}
)
client_kwargs = {
"api_key": anthropic_key,
"http_client": http_client
}
base_url = api_base_url if api_provider == "anthropic" else app_settings.anthropic_base_url
if base_url:
client_kwargs["base_url"] = base_url
self.anthropic_client = AsyncAnthropic(**client_kwargs)
logger.info("✅ Anthropic客户端初始化成功")
except Exception as e:
logger.error(f"Anthropic客户端初始化失败: {e}")
self.anthropic_client = None
else:
self.anthropic_client = None
logger.warning("Anthropic API key未配置")
base_url = api_base_url if api_provider == "anthropic" else app_settings.anthropic_base_url
client = AnthropicClient(anthropic_key, base_url, self.config)
self._anthropic_provider = AnthropicProvider(client)
# 初始化 Gemini
if api_provider == "gemini" and api_key:
client = GeminiClient(api_key, api_base_url, self.config)
self._gemini_provider = GeminiProvider(client)
@property
def enable_mcp(self) -> bool:
"""是否启用MCP工具"""
return self._enable_mcp
@enable_mcp.setter
def enable_mcp(self, value: bool):
"""设置MCP启用状态,如果禁用则清理缓存"""
if value is False and self._enable_mcp is True:
# 从启用变为禁用,清理缓存
self.clear_mcp_cache()
self._enable_mcp = value
def clear_mcp_cache(self):
"""
清理MCP工具缓存
当禁用MCP时调用此方法确保后续AI调用不会使用缓存的工具
同时更新 _tools_loaded 状态使下次调用时重新检查
"""
if self._cached_tools is not None:
logger.info(f"🔧 清理MCP工具缓存,移除 {len(self._cached_tools)} 个工具")
self._cached_tools = None
else:
logger.debug(f"🔧 MCP工具缓存已经是空,无需清理")
# 更新加载状态,确保下次调用会重新检查
self._tools_loaded = False
logger.debug(f"🔧 MCP工具状态已重置: enable_mcp={self._enable_mcp}, _tools_loaded=False")
def _get_provider(self, provider: Optional[str] = None) -> BaseAIProvider:
"""获取对应的 Provider"""
p = provider or self.api_provider
if p == "openai" and self._openai_provider:
return self._openai_provider
if p == "anthropic" and self._anthropic_provider:
return self._anthropic_provider
if p == "gemini" and self._gemini_provider:
return self._gemini_provider
raise ValueError(f"Provider {p} 未初始化")
async def _prepare_mcp_tools(self, auto_mcp: bool = True, force_refresh: bool = False) -> Optional[List[Dict]]:
"""
预处理MCP工具
检查用户MCP配置并加载可用工具
结果会被缓存避免重复加载
Args:
auto_mcp: 是否自动加载MCP工具来自调用方参数
force_refresh: 是否强制刷新缓存
Returns:
- None: 无可用工具未配置/未启用/加载失败
- List[Dict]: OpenAI格式的工具列表
"""
# 前置条件检查
if not self._enable_mcp:
logger.debug(f"🔧 MCP工具未启用 (_enable_mcp=False)")
# 即使有缓存也清理掉,确保不使用
self._cached_tools = None
self._tools_loaded = False
return None
if not auto_mcp:
logger.debug(f"🔧 auto_mcp=False,跳过MCP工具加载")
# 即使有缓存也清理掉,确保不使用
self._cached_tools = None
self._tools_loaded = False
return None
if not self.user_id:
logger.debug(f"🔧 MCP工具加载跳过: user_id未设置")
return None
if not self.db_session:
logger.debug(f"🔧 MCP工具加载跳过: db_session未设置")
return None
# 使用缓存(只有 enable_mcp=True 时才使用缓存)
if self._tools_loaded and not force_refresh:
if self._cached_tools:
logger.debug(f"🔧 使用缓存的MCP工具 ({len(self._cached_tools)}个)")
return self._cached_tools
try:
from app.services.mcp_tools_loader import mcp_tools_loader
self._cached_tools = await mcp_tools_loader.get_user_tools(
user_id=self.user_id,
db_session=self.db_session,
use_cache=True,
force_refresh=force_refresh
)
self._tools_loaded = True
if self._cached_tools:
logger.info(f"🔧 已加载 {len(self._cached_tools)} 个MCP工具")
else:
logger.debug(f"📭 用户 {self.user_id} 没有可用的MCP工具")
return self._cached_tools
except Exception as e:
logger.warning(f"⚠️ 加载MCP工具失败: {e}")
self._tools_loaded = True
self._cached_tools = None
return None
async def _handle_tool_calls(
self,
original_prompt: str,
response: Dict[str, Any],
max_rounds: int = 2,
**kwargs
) -> Dict[str, Any]:
"""
处理AI返回的工具调用
Args:
original_prompt: 原始提示词
response: AI响应包含tool_calls
max_rounds: 最大工具调用轮数
**kwargs: 传递给generate_text的其他参数
Returns:
最终的AI响应
"""
from app.mcp import mcp_client
tool_calls = response.get("tool_calls", [])
if not tool_calls or not self.user_id:
return response
result = {
"content": response.get("content", ""),
"tool_calls_made": 0,
"tools_used": [],
"finish_reason": response.get("finish_reason", ""),
"mcp_enhanced": True
}
prompt = original_prompt
for round_num in range(max_rounds):
logger.info(f"🔧 工具调用 - 第{round_num+1}/{max_rounds}轮,{len(tool_calls)}个工具")
try:
# 批量执行工具调用
tool_results = await mcp_client.batch_call_tools(
user_id=self.user_id,
tool_calls=tool_calls
)
# 记录使用的工具
for tc in tool_calls:
name = tc["function"]["name"]
if name not in result["tools_used"]:
result["tools_used"].append(name)
result["tool_calls_made"] += len(tool_calls)
# 构建工具上下文
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
# 更新提示词
if round_num == max_rounds - 1:
# 最后一轮,强制要求回答
prompt = f"{original_prompt}\n\n{tool_context}\n\n⚠️ 重要:请基于以上工具查询结果,给出完整详细的最终答案。不要再调用工具。"
tool_choice = "none"
else:
prompt = f"{original_prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,继续完成任务。"
tool_choice = kwargs.get("tool_choice", "auto")
# 继续调用AI
prov = self._get_provider(kwargs.get("provider"))
next_response = await prov.generate(
prompt=prompt,
model=kwargs.get("model") or self.default_model,
temperature=kwargs.get("temperature") or self.default_temperature,
max_tokens=kwargs.get("max_tokens") or self.default_max_tokens,
system_prompt=kwargs.get("system_prompt") or self.default_system_prompt,
tools=None if tool_choice == "none" else self._cached_tools,
tool_choice=tool_choice,
)
tool_calls = next_response.get("tool_calls", [])
if not tool_calls:
# 没有更多工具调用,返回结果
result["content"] = next_response.get("content", "")
result["finish_reason"] = next_response.get("finish_reason", "stop")
break
except Exception as e:
logger.error(f"❌ 工具调用失败: {e}")
result["content"] = response.get("content", "")
result["finish_reason"] = "tool_error"
break
return result
async def generate_text(
self,
prompt: str,
@@ -126,38 +323,67 @@ class AIService:
model: Optional[str] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
system_prompt: Optional[str] = None
) -> str:
system_prompt: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: Optional[str] = None,
auto_mcp: bool = True,
handle_tool_calls: bool = True,
mcp_max_rounds: Optional[int] = None,
) -> Dict[str, Any]:
"""
生成文本
生成文本自动支持MCP工具
Args:
prompt: 用户提示词
provider: AI提供商 (openai/anthropic)
provider: AI提供商
model: 模型名称
temperature: 温度参数
max_tokens: 最大token
temperature: 温度
max_tokens: 最大令牌
system_prompt: 系统提示词
tools: 手动指定的工具列表优先级高于自动加载
tool_choice: 工具选择策略
auto_mcp: 是否自动加载MCP工具默认True
handle_tool_calls: 是否自动处理工具调用默认True
mcp_max_rounds: 最大工具调用轮数None使用默认值3
Returns:
生成的文本
包含生成内容的字典
"""
provider = provider or self.api_provider
model = model or self.default_model
temperature = temperature or self.default_temperature
max_tokens = max_tokens or self.default_max_tokens
# 使用全局配置的MCP轮数(如果未指定)
if mcp_max_rounds is None:
mcp_max_rounds = app_settings.mcp_max_rounds
if provider == "openai":
return await self._generate_openai(
prompt, model, temperature, max_tokens, system_prompt
# 自动加载MCP工具
if auto_mcp and tools is None:
tools = await self._prepare_mcp_tools(auto_mcp=auto_mcp)
prov = self._get_provider(provider)
response = await prov.generate(
prompt=prompt,
model=model or self.default_model,
temperature=temperature or self.default_temperature,
max_tokens=max_tokens or self.default_max_tokens,
system_prompt=system_prompt or self.default_system_prompt,
tools=tools,
tool_choice=tool_choice,
)
# 处理工具调用
if handle_tool_calls and response.get("tool_calls"):
return await self._handle_tool_calls(
original_prompt=prompt,
response=response,
provider=provider,
model=model,
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_prompt,
tool_choice=tool_choice,
max_rounds=mcp_max_rounds,
)
elif provider == "anthropic":
return await self._generate_anthropic(
prompt, model, temperature, max_tokens, system_prompt
)
else:
raise ValueError(f"不支持的AI提供商: {provider}")
return response
async def generate_text_stream(
self,
prompt: str,
@@ -165,301 +391,123 @@ class AIService:
model: Optional[str] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
system_prompt: Optional[str] = None
system_prompt: Optional[str] = None,
tool_choice: Optional[str] = None,
auto_mcp: bool = True,
mcp_max_rounds: Optional[int] = None,
) -> AsyncGenerator[str, None]:
"""
流式生成文本
流式生成文本自动支持MCP工具
工具调用在 Provider 层通过流式方式处理支持真正的流式工具调用
Args:
prompt: 用户提示词
provider: AI提供商
model: 模型名称
temperature: 温度参数
max_tokens: 最大token
temperature: 温度
max_tokens: 最大令牌
system_prompt: 系统提示词
tool_choice: 工具选择策略"auto"/"none"/"required"
auto_mcp: 是否自动加载MCP工具
mcp_max_rounds: 最大工具调用轮数None使用默认值3
Yields:
生成的文本片段
生成的文本
"""
provider = provider or self.api_provider
model = model or self.default_model
temperature = temperature or self.default_temperature
max_tokens = max_tokens or self.default_max_tokens
logger.debug(f"🔧 generate_text_stream: auto_mcp={auto_mcp}, tool_choice={tool_choice}")
if provider == "openai":
async for chunk in self._generate_openai_stream(
prompt, model, temperature, max_tokens, system_prompt
):
yield chunk
elif provider == "anthropic":
async for chunk in self._generate_anthropic_stream(
prompt, model, temperature, max_tokens, system_prompt
):
yield chunk
else:
raise ValueError(f"不支持的AI提供商: {provider}")
async def _generate_openai(
tools_to_use = None
# 加载MCP工具
if auto_mcp:
tools_to_use = await self._prepare_mcp_tools(auto_mcp=auto_mcp)
if tools_to_use:
logger.info(f"🔧 已获取 {len(tools_to_use)} 个MCP工具")
# 流式生成(Provider 层处理工具调用)
prov = self._get_provider(provider)
logger.debug(f"🔧 开始流式生成,provider={provider or self.api_provider}, tools_count={len(tools_to_use) if tools_to_use else 0}")
async for chunk in prov.generate_stream(
prompt=prompt,
model=model or self.default_model,
temperature=temperature or self.default_temperature,
max_tokens=max_tokens or self.default_max_tokens,
system_prompt=system_prompt or self.default_system_prompt,
tools=tools_to_use,
tool_choice=tool_choice,
user_id=self.user_id,
):
yield chunk
async def call_with_json_retry(
self,
prompt: str,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str]
) -> str:
"""使用OpenAI生成文本"""
if not self.openai_http_client:
raise ValueError("OpenAI客户端未初始化,请检查API key配置")
system_prompt: Optional[str] = None,
max_retries: int = 3,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
provider: Optional[str] = None,
model: Optional[str] = None,
expected_type: Optional[str] = None,
auto_mcp: bool = True,
) -> Union[Dict, List]:
"""
带重试的 JSON 调用自动支持MCP工具
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
Args:
prompt: 用户提示词
system_prompt: 系统提示词
max_retries: 最大重试次数
temperature: 温度
max_tokens: 最大令牌数
provider: AI提供商
model: 模型名称
expected_type: 期望的返回类型"object""array"
auto_mcp: 是否自动加载MCP工具
Returns:
解析后的JSON数据
"""
last_response = ""
try:
logger.info(f"🔵 开始调用OpenAI API(直接HTTP请求)")
logger.info(f" - 模型: {model}")
logger.info(f" - 温度: {temperature}")
logger.info(f" - 最大tokens: {max_tokens}")
logger.info(f" - Prompt长度: {len(prompt)} 字符")
logger.info(f" - 消息数量: {len(messages)}")
for attempt in range(1, max_retries + 1):
current_prompt = prompt if attempt == 1 else self._add_json_hint(prompt, last_response, attempt)
url = f"{self.openai_base_url}/chat/completions"
headers = {
"Authorization": f"Bearer {self.openai_api_key}",
"Content-Type": "application/json"
}
payload = {
"model": model,
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens
}
logger.debug(f" - 请求URL: {url}")
logger.debug(f" - 请求头: Authorization=Bearer ***")
response = await self.openai_http_client.post(url, headers=headers, json=payload)
response.raise_for_status()
data = response.json()
logger.info(f"✅ OpenAI API调用成功")
logger.info(f" - 响应ID: {data.get('id', 'N/A')}")
logger.info(f" - 选项数量: {len(data.get('choices', []))}")
if not data.get('choices'):
logger.error("❌ OpenAI返回的choices为空")
raise ValueError("API返回的响应格式错误:choices字段为空")
choice = data['choices'][0]
message = choice.get('message', {})
finish_reason = choice.get('finish_reason')
# DeepSeek R1特殊处理:只使用content(最终答案),忽略reasoning_content(思考过程)
# reasoning_content是AI的思考过程,不是我们需要的JSON结果
content = message.get('content', '')
# 检查是否因达到长度限制而截断
if finish_reason == 'length':
logger.warning(f"⚠️ 响应因达到max_tokens限制而被截断")
logger.warning(f" - 当前max_tokens: {max_tokens}")
logger.warning(f" - 建议: 增加max_tokens参数(推荐2000+")
if content:
logger.info(f" - 返回内容长度: {len(content)} 字符")
logger.info(f" - 完成原因: {finish_reason}")
logger.info(f" - 返回内容预览(前200字符): {content[:200]}")
return content
else:
logger.error("❌ AI返回了空内容")
logger.error(f" - 完整响应: {data}")
logger.error(f" - 完成原因: {finish_reason}")
# 提供更详细的错误信息
if finish_reason == 'length':
raise ValueError(f"AI响应被截断且无有效内容。请增加max_tokens参数(当前: {max_tokens},建议: 2000+")
else:
raise ValueError(f"AI返回了空内容(finish_reason: {finish_reason}),请检查API配置或稍后重试")
except httpx.HTTPStatusError as e:
logger.error(f"❌ OpenAI API调用失败 (HTTP {e.response.status_code})")
logger.error(f" - 错误信息: {e.response.text}")
logger.error(f" - 模型: {model}")
raise Exception(f"API返回错误 ({e.response.status_code}): {e.response.text}")
except Exception as e:
logger.error(f"❌ OpenAI API调用失败")
logger.error(f" - 错误类型: {type(e).__name__}")
logger.error(f" - 错误信息: {str(e)}")
logger.error(f" - 模型: {model}")
raise
async def _generate_openai_stream(
self,
prompt: str,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str]
) -> AsyncGenerator[str, None]:
"""使用OpenAI流式生成文本"""
if not self.openai_http_client:
raise ValueError("OpenAI客户端未初始化,请检查API key配置")
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
try:
logger.info(f"🔵 开始调用OpenAI流式API(直接HTTP请求)")
logger.info(f" - 模型: {model}")
logger.info(f" - Prompt长度: {len(prompt)} 字符")
logger.info(f" - 最大tokens: {max_tokens}")
url = f"{self.openai_base_url}/chat/completions"
headers = {
"Authorization": f"Bearer {self.openai_api_key}",
"Content-Type": "application/json"
}
payload = {
"model": model,
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
"stream": True
}
async with self.openai_http_client.stream('POST', url, headers=headers, json=payload) as response:
response.raise_for_status()
logger.info(f"✅ OpenAI流式API连接成功,开始接收数据...")
chunk_count = 0
has_content = False
finish_reason = None
async for line in response.aiter_lines():
if line.startswith('data: '):
data_str = line[6:]
if data_str.strip() == '[DONE]':
break
try:
import json
data = json.loads(data_str)
if 'choices' in data and len(data['choices']) > 0:
choice = data['choices'][0]
delta = choice.get('delta', {})
finish_reason = choice.get('finish_reason') or finish_reason
# DeepSeek R1特殊处理:只收集content(最终答案),忽略reasoning_content(思考过程)
# reasoning_content是AI的思考过程,不是我们需要的JSON结果
content = delta.get('content', '')
if content:
chunk_count += 1
has_content = True
yield content
except json.JSONDecodeError:
continue
# 检查是否因长度限制截断
if finish_reason == 'length':
logger.warning(f"⚠️ 流式响应因达到max_tokens限制而被截断")
logger.warning(f" - 当前max_tokens: {max_tokens}")
logger.warning(f" - 建议: 增加max_tokens参数(推荐2000+")
if not has_content:
logger.warning(f"⚠️ 流式响应未返回任何内容")
logger.warning(f" - 完成原因: {finish_reason}")
logger.info(f"✅ OpenAI流式生成完成,共接收 {chunk_count} 个chunk,完成原因: {finish_reason}")
except httpx.TimeoutException as e:
logger.error(f"❌ OpenAI流式API超时")
logger.error(f" - 错误: {str(e)}")
logger.error(f" - 提示: 请检查网络连接或考虑缩短prompt长度")
raise TimeoutError(f"AI服务超时(180秒),请稍后重试或减少上下文长度") from e
except httpx.HTTPStatusError as e:
logger.error(f"❌ OpenAI流式API调用失败 (HTTP {e.response.status_code})")
logger.error(f" - 错误信息: {await e.response.aread()}")
raise
except Exception as e:
logger.error(f"❌ OpenAI流式API调用失败: {str(e)}")
logger.error(f" - 错误类型: {type(e).__name__}")
raise
async def _generate_anthropic(
self,
prompt: str,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str]
) -> str:
"""使用Anthropic生成文本"""
if not self.anthropic_client:
raise ValueError("Anthropic客户端未初始化,请检查API key配置")
try:
response = await self.anthropic_client.messages.create(
result = await self.generate_text(
prompt=current_prompt,
provider=provider,
model=model,
max_tokens=max_tokens,
temperature=temperature,
system=system_prompt or "",
messages=[{"role": "user", "content": prompt}]
max_tokens=max_tokens,
system_prompt=system_prompt,
auto_mcp=auto_mcp,
handle_tool_calls=True,
)
return response.content[0].text
except Exception as e:
logger.error(f"Anthropic API调用失败: {str(e)}")
raise
async def _generate_anthropic_stream(
self,
prompt: str,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str]
) -> AsyncGenerator[str, None]:
"""使用Anthropic流式生成文本"""
if not self.anthropic_client:
raise ValueError("Anthropic客户端未初始化,请检查API key配置")
try:
logger.info(f"🔵 开始调用Anthropic流式API")
logger.info(f" - 模型: {model}")
logger.info(f" - Prompt长度: {len(prompt)} 字符")
logger.info(f" - 最大tokens: {max_tokens}")
async with self.anthropic_client.messages.stream(
model=model,
max_tokens=max_tokens,
temperature=temperature,
system=system_prompt or "",
messages=[{"role": "user", "content": prompt}]
) as stream:
logger.info(f"✅ Anthropic流式API连接成功,开始接收数据...")
chunk_count = 0
async for text in stream.text_stream:
chunk_count += 1
yield text
logger.info(f"✅ Anthropic流式生成完成,共接收 {chunk_count} 个chunk")
except httpx.TimeoutException as e:
logger.error(f"❌ Anthropic流式API超时")
logger.error(f" - 错误: {str(e)}")
raise TimeoutError(f"AI服务超时(180秒),请稍后重试或减少上下文长度") from e
except Exception as e:
logger.error(f"❌ Anthropic流式API调用失败: {str(e)}")
logger.error(f" - 错误类型: {type(e).__name__}")
raise
last_response = result.get("content", "")
try:
data = parse_json(last_response)
if expected_type == "object" and not isinstance(data, dict):
raise ValueError("期望对象")
if expected_type == "array" and not isinstance(data, list):
raise ValueError("期望数组")
return data
except Exception as e:
if attempt == max_retries:
raise ValueError(f"JSON 解析失败: {e}")
raise ValueError("JSON 调用失败")
@staticmethod
def _add_json_hint(prompt: str, failed: str, attempt: int) -> str:
return f"{prompt}\n\n⚠️ 第{attempt}次重试,请返回纯JSON,不要markdown包裹。上次错误: {failed[:200]}..."
# 创建全局AI服务实例
ai_service = AIService()
@staticmethod
def _clean_json_response(text: str) -> str:
"""清洗 JSON 响应"""
return clean_json_response(text)
def create_user_ai_service(
@@ -468,21 +516,50 @@ def create_user_ai_service(
api_base_url: str,
model_name: str,
temperature: float,
max_tokens: int
max_tokens: int,
system_prompt: Optional[str] = None,
) -> AIService:
"""创建用户 AI 服务(不带MCP支持)"""
return AIService(
api_provider=api_provider,
api_key=api_key,
api_base_url=api_base_url,
default_model=model_name,
default_temperature=temperature,
default_max_tokens=max_tokens,
default_system_prompt=system_prompt,
)
def create_user_ai_service_with_mcp(
api_provider: str,
api_key: str,
api_base_url: str,
model_name: str,
temperature: float,
max_tokens: int,
user_id: str,
db_session,
system_prompt: Optional[str] = None,
enable_mcp: bool = True,
) -> AIService:
"""
根据用户设置创建AI服务实例
创建支持MCP的用户AI服务
Args:
api_provider: API提供商
api_provider: AI提供商
api_key: API密钥
api_base_url: API基础URL
model_name: 模型名称
temperature: 温度参数
max_tokens: 最大tokens
temperature: 温度
max_tokens: 最大令牌数
user_id: 用户ID用于加载MCP工具
db_session: 数据库会话
system_prompt: 系统提示词
enable_mcp: 是否启用MCP工具
Returns:
AIService实例
配置好的AIService实例
"""
return AIService(
api_provider=api_provider,
@@ -490,5 +567,9 @@ def create_user_ai_service(
api_base_url=api_base_url,
default_model=model_name,
default_temperature=temperature,
default_max_tokens=max_tokens
default_max_tokens=max_tokens,
default_system_prompt=system_prompt,
user_id=user_id,
db_session=db_session,
enable_mcp=enable_mcp,
)
@@ -0,0 +1,606 @@
"""自动角色引入服务 - 在续写大纲时根据剧情推进自动引入新角色"""
from typing import List, Dict, Any, Optional, Callable, Awaitable
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
import json
from app.models.character import Character
from app.models.relationship import CharacterRelationship, Organization, OrganizationMember, RelationshipType
from app.models.project import Project
from app.services.ai_service import AIService
from app.services.prompt_service import PromptService
from app.logger import get_logger
logger = get_logger(__name__)
class AutoCharacterService:
"""自动角色引入服务"""
def __init__(self, ai_service: AIService):
self.ai_service = ai_service
async def analyze_and_create_characters(
self,
project_id: str,
outline_content: str,
existing_characters: List[Character],
db: AsyncSession,
user_id: str = None,
enable_mcp: bool = True,
all_chapters_brief: str = "",
start_chapter: int = 1,
chapter_count: int = 3,
plot_stage: str = "发展",
story_direction: str = "继续推进主线剧情",
preview_only: bool = False,
progress_callback: Optional[Callable[[str], Awaitable[None]]] = None
) -> Dict[str, Any]:
"""
预测性分析并创建需要的新角色方案A先角色后大纲
Args:
project_id: 项目ID
outline_content: 当前批次大纲内容用于向后兼容实际不使用
existing_characters: 现有角色列表
db: 数据库会话
user_id: 用户ID(用于MCP和自定义提示词)
enable_mcp: 是否启用MCP增强
all_chapters_brief: 已有章节概览
start_chapter: 起始章节号
chapter_count: 续写章节数
plot_stage: 剧情阶段
story_direction: 故事发展方向
preview_only: 仅预测不创建用于角色确认机制
Returns:
{
"new_characters": [角色对象列表], # preview_only=True时为空
"relationships_created": [关系对象列表], # preview_only=True时为空
"character_count": 新增角色数量,
"analysis_result": AI分析结果,
"predicted_characters": [预测的角色数据] # 仅preview_only=True时返回
"needs_new_characters": bool,
"reason": str
}
"""
logger.info(f"🎭 【方案A】预测性分析:检测是否需要引入新角色...")
logger.info(f" - 项目ID: {project_id}")
logger.info(f" - 续写计划: 第{start_chapter}章起,共{chapter_count}")
logger.info(f" - 剧情阶段: {plot_stage}")
logger.info(f" - 发展方向: {story_direction}")
logger.info(f" - 现有角色数: {len(existing_characters)}")
# 1. 获取项目信息
project_result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = project_result.scalar_one_or_none()
if not project:
raise ValueError("项目不存在")
# 2. 构建现有角色信息摘要
existing_chars_summary = self._build_character_summary(existing_characters)
# 3. AI预测性分析是否需要新角色
analysis_result = await self._analyze_character_needs(
project=project,
outline_content=outline_content, # 保留参数向后兼容
existing_chars_summary=existing_chars_summary,
db=db,
user_id=user_id,
enable_mcp=enable_mcp,
all_chapters_brief=all_chapters_brief,
start_chapter=start_chapter,
chapter_count=chapter_count,
plot_stage=plot_stage,
story_direction=story_direction
)
# 4. 判断是否需要创建角色
if not analysis_result or not analysis_result.get("needs_new_characters"):
logger.info("✅ AI判断:当前剧情不需要引入新角色")
return {
"new_characters": [],
"relationships_created": [],
"character_count": 0,
"analysis_result": analysis_result,
"predicted_characters": [],
"needs_new_characters": False,
"reason": analysis_result.get("reason", "当前剧情不需要新角色")
}
# 5. 如果是预览模式,仅返回预测结果,不创建角色
if preview_only:
character_specs = analysis_result.get("character_specifications", [])
logger.info(f"🔮 预览模式:预测到 {len(character_specs)} 个角色,不创建数据库记录")
return {
"new_characters": [],
"relationships_created": [],
"character_count": 0,
"analysis_result": analysis_result,
"predicted_characters": character_specs,
"needs_new_characters": True,
"reason": analysis_result.get("reason", "预测需要新角色")
}
# 6. 批量生成新角色(非预览模式)
new_characters = []
relationships_created = []
character_specs = analysis_result.get("character_specifications", [])
logger.info(f"🎯 AI建议引入 {len(character_specs)} 个新角色")
for idx, spec in enumerate(character_specs):
try:
spec_name = spec.get('name', spec.get('role_description', '未命名'))
logger.info(f" [{idx+1}/{len(character_specs)}] 生成角色规格: {spec_name}")
logger.debug(f" 角色规格内容: {json.dumps(spec, ensure_ascii=False)}")
if progress_callback:
await progress_callback(f"🎨 [{idx+1}/{len(character_specs)}] 生成角色详情: {spec_name}")
# 生成角色详细信息
character_data = await self._generate_character_details(
spec=spec,
project=project,
existing_characters=existing_characters + new_characters, # 包含新创建的
db=db,
user_id=user_id,
enable_mcp=enable_mcp
)
logger.debug(f" AI生成的角色数据: {json.dumps(character_data, ensure_ascii=False)[:200]}")
if progress_callback:
await progress_callback(f"💾 [{idx+1}/{len(character_specs)}] 保存角色: {character_data.get('name', spec_name)}")
# 创建角色记录
character = await self._create_character_record(
project_id=project_id,
character_data=character_data,
db=db
)
new_characters.append(character)
logger.info(f" ✅ 创建新角色: {character.name} ({character.role_type}), ID: {character.id}")
if progress_callback:
await progress_callback(f"✅ [{idx+1}/{len(character_specs)}] 角色创建成功: {character.name}")
# 建立关系(兼容两种字段名)
relationships_data = character_data.get("relationships") or character_data.get("relationships_array", [])
logger.info(f" 🔍 检查关系数据:")
logger.info(f" - relationships字段: {character_data.get('relationships')}")
logger.info(f" - relationships_array字段: {character_data.get('relationships_array')}")
logger.info(f" - 最终使用的数据: {relationships_data}")
logger.info(f" - 关系数量: {len(relationships_data) if relationships_data else 0}")
if relationships_data:
logger.info(f" 🔗 开始创建 {len(relationships_data)} 条关系...")
for idx, rel in enumerate(relationships_data):
logger.info(f" [{idx+1}] {rel.get('target_character_name')} - {rel.get('relationship_type')}")
if progress_callback:
await progress_callback(f"🔗 [{idx+1}/{len(character_specs)}] 建立 {len(relationships_data)} 个关系")
else:
logger.warning(f" ⚠️ AI返回的角色数据中没有关系信息!")
logger.warning(f" 完整的character_data keys: {list(character_data.keys())}")
rels = await self._create_relationships(
new_character=character,
relationship_specs=relationships_data,
existing_characters=existing_characters + new_characters,
project_id=project_id,
db=db
)
relationships_created.extend(rels)
logger.info(f" ✅ 实际创建了 {len(rels)} 条关系记录")
except Exception as e:
logger.error(f" ❌ 创建角色失败: {e}", exc_info=True)
continue
# 7. 提交事务(注意:这里只flush,让调用方commit
await db.flush()
logger.info(f"🎉 自动角色引入完成: 新增{len(new_characters)}个角色, {len(relationships_created)}条关系")
return {
"new_characters": new_characters,
"relationships_created": relationships_created,
"character_count": len(new_characters),
"analysis_result": analysis_result
}
def _build_character_summary(self, characters: List[Character]) -> str:
"""构建现有角色摘要"""
if not characters:
return "暂无角色"
summary = []
for char in characters:
char_type = "组织" if char.is_organization else "角色"
role_desc = char.role_type or "未知"
personality = (char.personality or "")[:50]
summary.append(f"- {char.name} ({char_type}, {role_desc}): {personality}")
return "\n".join(summary[:20]) # 最多显示20个
async def _analyze_character_needs(
self,
project: Project,
outline_content: str,
existing_chars_summary: str,
db: AsyncSession,
user_id: str,
enable_mcp: bool,
all_chapters_brief: str = "",
start_chapter: int = 1,
chapter_count: int = 3,
plot_stage: str = "发展",
story_direction: str = "继续推进主线剧情"
) -> Dict[str, Any]:
"""AI预测性分析是否需要新角色(方案A"""
# 构建分析提示词
template = await PromptService.get_template(
"AUTO_CHARACTER_ANALYSIS",
user_id,
db
)
# 使用新的预测性分析参数
prompt = PromptService.format_prompt(
template,
title=project.title,
theme=project.theme or "未设定",
genre=project.genre or "未设定",
time_period=project.world_time_period or "未设定",
location=project.world_location or "未设定",
atmosphere=project.world_atmosphere or "未设定",
existing_characters=existing_chars_summary,
all_chapters_brief=all_chapters_brief,
start_chapter=start_chapter,
chapter_count=chapter_count,
plot_stage=plot_stage,
story_direction=story_direction
)
try:
# 使用统一的JSON调用方法(支持自动MCP工具加载)
analysis = await self.ai_service.call_with_json_retry(
prompt=prompt,
max_retries=3,
)
logger.info(f" ✅ AI分析完成: needs_new_characters={analysis.get('needs_new_characters')}")
return analysis
except json.JSONDecodeError as e:
logger.error(f" ❌ 角色需求分析JSON解析失败: {e}")
return {"needs_new_characters": False}
except Exception as e:
logger.error(f" ❌ 角色需求分析失败: {e}")
return {"needs_new_characters": False}
async def _generate_character_details(
self,
spec: Dict[str, Any],
project: Project,
existing_characters: List[Character],
db: AsyncSession,
user_id: str,
enable_mcp: bool
) -> Dict[str, Any]:
"""生成角色详细信息"""
# 🎯 获取项目职业列表
from app.models.career import Career
careers_result = await db.execute(
select(Career)
.where(Career.project_id == project.id)
.order_by(Career.type, Career.name)
)
careers = careers_result.scalars().all()
# 构建职业信息摘要(包含最高阶段信息)
careers_info = ""
if careers:
main_careers = [c for c in careers if c.type == 'main']
sub_careers = [c for c in careers if c.type == 'sub']
if main_careers:
careers_info += "\n\n可用主职业列表(请在career_info中填写职业名称和阶段):\n"
for career in main_careers:
careers_info += f"- 名称: {career.name}, 最高阶段: {career.max_stage}"
if career.description:
careers_info += f", 描述: {career.description[:50]}"
careers_info += "\n"
if sub_careers:
careers_info += "\n可用副职业列表(请在career_info中填写职业名称和阶段):\n"
for career in sub_careers[:5]:
careers_info += f"- 名称: {career.name}, 最高阶段: {career.max_stage}"
if career.description:
careers_info += f", 描述: {career.description[:50]}"
careers_info += "\n"
careers_info += "\n⚠️ 重要提示:生成角色时,职业阶段不能超过该职业的最高阶段!\n"
# 构建角色生成提示词
template = await PromptService.get_template(
"AUTO_CHARACTER_GENERATION",
user_id,
db
)
existing_chars_summary = self._build_character_summary(existing_characters)
prompt = PromptService.format_prompt(
template,
title=project.title,
genre=project.genre or "未设定",
theme=project.theme or "未设定",
time_period=project.world_time_period or "未设定",
location=project.world_location or "未设定",
atmosphere=project.world_atmosphere or "未设定",
rules=project.world_rules or "未设定",
existing_characters=existing_chars_summary + careers_info,
plot_context="根据剧情需要引入的新角色",
character_specification=json.dumps(spec, ensure_ascii=False, indent=2),
mcp_references="" # MCP工具通过AI服务自动加载
)
logger.info(f"🔧 角色详情生成: enable_mcp={enable_mcp}")
# 调用AI生成
try:
character_data = await self.ai_service.call_with_json_retry(
prompt=prompt,
max_retries=2, # 减少重试次数以加快速度
)
char_name = character_data.get('name', '未知')
logger.info(f" ✅ 角色详情生成成功: {char_name}")
logger.debug(f" 角色数据字段: {list(character_data.keys())}")
# 确保关键字段存在
if 'name' not in character_data or not character_data['name']:
logger.warning(f" ⚠️ AI返回的角色数据缺少name字段,使用规格中的信息")
character_data['name'] = spec.get('name', f"新角色{spec.get('role_description', '')[:10]}")
return character_data
except Exception as e:
logger.error(f" ❌ 生成角色详情失败: {e}")
raise
async def _create_character_record(
self,
project_id: str,
character_data: Dict[str, Any],
db: AsyncSession
) -> Character:
"""创建角色数据库记录"""
is_organization = character_data.get("is_organization", False)
# 提取职业信息(支持通过名称匹配)
career_info = character_data.get("career_info", {})
raw_main_career_name = career_info.get("main_career_name") if career_info else None
main_career_stage = career_info.get("main_career_stage", 1) if career_info else None
raw_sub_careers_data = career_info.get("sub_careers", []) if career_info else []
# 🔧 通过职业名称匹配数据库中的职业ID
from app.models.career import Career, CharacterCareer
main_career_id = None
sub_careers_data = []
# 匹配主职业名称
if raw_main_career_name and not is_organization:
career_check = await db.execute(
select(Career).where(
Career.name == raw_main_career_name,
Career.project_id == project_id,
Career.type == 'main'
)
)
matched_career = career_check.scalar_one_or_none()
if matched_career:
main_career_id = matched_career.id
# ✅ 验证阶段不超过最高阶段
if main_career_stage and main_career_stage > matched_career.max_stage:
logger.warning(f" ⚠️ AI返回的主职业阶段({main_career_stage})超过最高阶段({matched_career.max_stage}),自动修正为最高阶段")
main_career_stage = matched_career.max_stage
logger.info(f" ✅ 主职业名称匹配成功: {raw_main_career_name} -> ID: {main_career_id}, 阶段: {main_career_stage}/{matched_career.max_stage}")
else:
logger.warning(f" ⚠️ AI返回的主职业名称未找到: {raw_main_career_name}")
# 匹配副职业名称
if raw_sub_careers_data and not is_organization and isinstance(raw_sub_careers_data, list):
for sub_data in raw_sub_careers_data[:2]:
if isinstance(sub_data, dict):
career_name = sub_data.get('career_name')
if career_name:
career_check = await db.execute(
select(Career).where(
Career.name == career_name,
Career.project_id == project_id,
Career.type == 'sub'
)
)
matched_career = career_check.scalar_one_or_none()
if matched_career:
sub_stage = sub_data.get('stage', 1)
# ✅ 验证阶段不超过最高阶段
if sub_stage > matched_career.max_stage:
logger.warning(f" ⚠️ AI返回的副职业阶段({sub_stage})超过最高阶段({matched_career.max_stage}),自动修正为最高阶段")
sub_stage = matched_career.max_stage
sub_careers_data.append({
'career_id': matched_career.id,
'stage': sub_stage
})
logger.info(f" ✅ 副职业名称匹配成功: {career_name} -> ID: {matched_career.id}, 阶段: {sub_stage}/{matched_career.max_stage}")
else:
logger.warning(f" ⚠️ AI返回的副职业名称未找到: {career_name}")
# 创建角色
character = Character(
project_id=project_id,
name=character_data.get("name", "未命名角色"),
age=str(character_data.get("age", "")),
gender=character_data.get("gender"),
is_organization=is_organization,
role_type=character_data.get("role_type", "supporting"),
personality=character_data.get("personality", ""),
background=character_data.get("background", ""),
appearance=character_data.get("appearance", ""),
relationships=character_data.get("relationships_text", ""),
organization_type=character_data.get("organization_type") if is_organization else None,
organization_purpose=character_data.get("organization_purpose") if is_organization else None,
traits=json.dumps(character_data.get("traits", []), ensure_ascii=False) if character_data.get("traits") else None,
main_career_id=main_career_id,
main_career_stage=main_career_stage if main_career_id else None,
sub_careers=json.dumps(sub_careers_data, ensure_ascii=False) if sub_careers_data else None
)
db.add(character)
await db.flush()
# 处理主职业关联
if main_career_id and not is_organization:
char_career = CharacterCareer(
character_id=character.id,
career_id=main_career_id,
career_type='main',
current_stage=main_career_stage,
stage_progress=0
)
db.add(char_career)
logger.info(f" ✅ 创建主职业关联: {character.name} -> {raw_main_career_name}")
# 处理副职业关联
if sub_careers_data and not is_organization:
for sub_data in sub_careers_data:
char_career = CharacterCareer(
character_id=character.id,
career_id=sub_data['career_id'],
career_type='sub',
current_stage=sub_data['stage'],
stage_progress=0
)
db.add(char_career)
logger.info(f" ✅ 创建副职业关联: {character.name}, 数量: {len(sub_careers_data)}")
# 如果是组织,创建Organization记录
if is_organization:
org = Organization(
character_id=character.id,
project_id=project_id,
member_count=0,
power_level=character_data.get("power_level", 50),
location=character_data.get("location"),
motto=character_data.get("motto"),
color=character_data.get("color")
)
db.add(org)
await db.flush()
logger.info(f" ✅ 创建组织详情: {character.name}")
return character
async def _create_relationships(
self,
new_character: Character,
relationship_specs: List[Dict[str, Any]],
existing_characters: List[Character],
project_id: str,
db: AsyncSession
) -> List[CharacterRelationship]:
"""创建角色关系"""
if not relationship_specs:
return []
relationships = []
for rel_spec in relationship_specs:
try:
target_name = rel_spec.get("target_character_name")
if not target_name:
continue
# 查找目标角色
target_char = next(
(c for c in existing_characters if c.name == target_name),
None
)
if not target_char:
logger.warning(f" ⚠️ 目标角色不存在: {target_name}")
continue
# 检查关系是否已存在
existing_rel = await db.execute(
select(CharacterRelationship).where(
CharacterRelationship.project_id == project_id,
CharacterRelationship.character_from_id == new_character.id,
CharacterRelationship.character_to_id == target_char.id
)
)
if existing_rel.scalar_one_or_none():
logger.debug(f" ️ 关系已存在: {new_character.name} -> {target_name}")
continue
# 创建关系
relationship = CharacterRelationship(
project_id=project_id,
character_from_id=new_character.id,
character_to_id=target_char.id,
relationship_name=rel_spec.get("relationship_type", "未知关系"),
intimacy_level=rel_spec.get("intimacy_level", 50),
description=rel_spec.get("description", ""),
status=rel_spec.get("status", "active"),
source="auto" # 标记为自动生成
)
# 尝试匹配预定义关系类型
rel_type_name = rel_spec.get("relationship_type")
if rel_type_name:
rel_type_result = await db.execute(
select(RelationshipType).where(
RelationshipType.name == rel_type_name
)
)
rel_type = rel_type_result.scalar_one_or_none()
if rel_type:
relationship.relationship_type_id = rel_type.id
db.add(relationship)
relationships.append(relationship)
logger.info(
f" ✅ 创建关系: {new_character.name} -> {target_name} "
f"({rel_spec.get('relationship_type', '未知')})"
)
except Exception as e:
logger.warning(f" ❌ 创建关系失败: {e}")
continue
return relationships
# 全局实例缓存
_auto_character_service_instance: Optional[AutoCharacterService] = None
def get_auto_character_service(ai_service: AIService) -> AutoCharacterService:
"""获取自动角色服务实例(单例模式)"""
global _auto_character_service_instance
if _auto_character_service_instance is None:
_auto_character_service_instance = AutoCharacterService(ai_service)
return _auto_character_service_instance
@@ -0,0 +1,497 @@
"""自动组织引入服务 - 在续写大纲时根据剧情推进自动引入新组织"""
from typing import List, Dict, Any, Optional, Callable, Awaitable
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
import json
from app.models.character import Character
from app.models.relationship import Organization, OrganizationMember
from app.models.project import Project
from app.services.ai_service import AIService
from app.services.prompt_service import PromptService
from app.logger import get_logger
logger = get_logger(__name__)
class AutoOrganizationService:
"""自动组织引入服务"""
def __init__(self, ai_service: AIService):
self.ai_service = ai_service
async def analyze_and_create_organizations(
self,
project_id: str,
outline_content: str,
existing_characters: List[Character],
existing_organizations: List[Dict[str, Any]],
db: AsyncSession,
user_id: str = None,
enable_mcp: bool = True,
all_chapters_brief: str = "",
start_chapter: int = 1,
chapter_count: int = 3,
plot_stage: str = "发展",
story_direction: str = "继续推进主线剧情",
preview_only: bool = False,
progress_callback: Optional[Callable[[str], Awaitable[None]]] = None
) -> Dict[str, Any]:
"""
预测性分析并创建需要的新组织
Args:
project_id: 项目ID
outline_content: 当前批次大纲内容用于向后兼容实际不使用
existing_characters: 现有角色列表
existing_organizations: 现有组织列表
db: 数据库会话
user_id: 用户ID(用于MCP和自定义提示词)
enable_mcp: 是否启用MCP增强
all_chapters_brief: 已有章节概览
start_chapter: 起始章节号
chapter_count: 续写章节数
plot_stage: 剧情阶段
story_direction: 故事发展方向
preview_only: 仅预测不创建用于组织确认机制
Returns:
{
"new_organizations": [组织对象列表], # preview_only=True时为空
"members_created": [成员关系列表], # preview_only=True时为空
"organization_count": 新增组织数量,
"analysis_result": AI分析结果,
"predicted_organizations": [预测的组织数据] # 仅preview_only=True时返回
"needs_new_organizations": bool,
"reason": str
}
"""
logger.info(f"🏛️ 【组织引入】预测性分析:检测是否需要引入新组织...")
logger.info(f" - 项目ID: {project_id}")
logger.info(f" - 续写计划: 第{start_chapter}章起,共{chapter_count}")
logger.info(f" - 剧情阶段: {plot_stage}")
logger.info(f" - 发展方向: {story_direction}")
logger.info(f" - 现有角色数: {len(existing_characters)}")
logger.info(f" - 现有组织数: {len(existing_organizations)}")
# 1. 获取项目信息
project_result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = project_result.scalar_one_or_none()
if not project:
raise ValueError("项目不存在")
# 2. 构建现有组织信息摘要
existing_orgs_summary = self._build_organization_summary(existing_organizations)
existing_chars_summary = self._build_character_summary(existing_characters)
# 3. AI预测性分析是否需要新组织
if progress_callback:
await progress_callback("🤖 AI分析组织需求...")
analysis_result = await self._analyze_organization_needs(
project=project,
outline_content=outline_content,
existing_orgs_summary=existing_orgs_summary,
existing_chars_summary=existing_chars_summary,
db=db,
user_id=user_id,
enable_mcp=enable_mcp,
all_chapters_brief=all_chapters_brief,
start_chapter=start_chapter,
chapter_count=chapter_count,
plot_stage=plot_stage,
story_direction=story_direction
)
if progress_callback:
await progress_callback("✅ 组织需求分析完成")
# 4. 判断是否需要创建组织
if not analysis_result or not analysis_result.get("needs_new_organizations"):
logger.info("✅ AI判断:当前剧情不需要引入新组织")
return {
"new_organizations": [],
"members_created": [],
"organization_count": 0,
"analysis_result": analysis_result,
"predicted_organizations": [],
"needs_new_organizations": False,
"reason": analysis_result.get("reason", "当前剧情不需要新组织")
}
# 5. 如果是预览模式,仅返回预测结果,不创建组织
if preview_only:
organization_specs = analysis_result.get("organization_specifications", [])
logger.info(f"🔮 预览模式:预测到 {len(organization_specs)} 个组织,不创建数据库记录")
return {
"new_organizations": [],
"members_created": [],
"organization_count": 0,
"analysis_result": analysis_result,
"predicted_organizations": organization_specs,
"needs_new_organizations": True,
"reason": analysis_result.get("reason", "预测需要新组织")
}
# 6. 批量生成新组织(非预览模式)
new_organizations = []
members_created = []
organization_specs = analysis_result.get("organization_specifications", [])
logger.info(f"🎯 AI建议引入 {len(organization_specs)} 个新组织")
for idx, spec in enumerate(organization_specs):
try:
spec_name = spec.get('name', spec.get('organization_description', '未命名'))
logger.info(f" [{idx+1}/{len(organization_specs)}] 生成组织规格: {spec_name}")
logger.debug(f" 组织规格内容: {json.dumps(spec, ensure_ascii=False)}")
if progress_callback:
await progress_callback(f"🏛️ [{idx+1}/{len(organization_specs)}] 生成组织详情: {spec_name}")
# 生成组织详细信息
organization_data = await self._generate_organization_details(
spec=spec,
project=project,
existing_characters=existing_characters,
existing_organizations=existing_organizations,
db=db,
user_id=user_id,
enable_mcp=enable_mcp
)
logger.debug(f" AI生成的组织数据: {json.dumps(organization_data, ensure_ascii=False)[:200]}")
if progress_callback:
await progress_callback(f"💾 [{idx+1}/{len(organization_specs)}] 保存组织: {organization_data.get('name', spec_name)}")
# 创建组织记录(先创建Character记录,再创建Organization记录)
character, organization = await self._create_organization_record(
project_id=project_id,
organization_data=organization_data,
db=db
)
new_organizations.append({
"character": character,
"organization": organization
})
logger.info(f" ✅ 创建新组织: {character.name}, ID: {organization.id}")
if progress_callback:
await progress_callback(f"✅ [{idx+1}/{len(organization_specs)}] 组织创建成功: {character.name}")
# 建立成员关系
members_data = organization_data.get("initial_members", [])
if members_data:
logger.info(f" 🔗 开始创建 {len(members_data)} 个成员关系...")
if progress_callback:
await progress_callback(f"🔗 [{idx+1}/{len(organization_specs)}] 建立 {len(members_data)} 个成员关系")
members = await self._create_member_relationships(
organization=organization,
member_specs=members_data,
existing_characters=existing_characters,
project_id=project_id,
db=db
)
members_created.extend(members)
logger.info(f" ✅ 实际创建了 {len(members)} 个成员关系记录")
except Exception as e:
logger.error(f" ❌ 创建组织失败: {e}", exc_info=True)
continue
# 7. 提交事务(注意:这里只flush,让调用方commit
await db.flush()
logger.info(f"🎉 自动组织引入完成: 新增{len(new_organizations)}个组织, {len(members_created)}个成员关系")
return {
"new_organizations": new_organizations,
"members_created": members_created,
"organization_count": len(new_organizations),
"analysis_result": analysis_result,
"predicted_organizations": [],
"needs_new_organizations": True,
"reason": analysis_result.get("reason", "")
}
def _build_organization_summary(self, organizations: List[Dict[str, Any]]) -> str:
"""构建现有组织摘要"""
if not organizations:
return "暂无组织"
summary = []
for org in organizations:
org_name = org.get("name", "未知")
org_type = org.get("organization_type", "未知类型")
power_level = org.get("power_level", 50)
purpose = (org.get("organization_purpose") or "")[:50]
summary.append(f"- {org_name} ({org_type}, 势力等级:{power_level}): {purpose}")
return "\n".join(summary[:15]) # 最多显示15个
def _build_character_summary(self, characters: List[Character]) -> str:
"""构建现有角色摘要"""
if not characters:
return "暂无角色"
summary = []
for char in characters:
if not char.is_organization: # 只统计非组织角色
char_role = char.role_type or "未知"
personality = (char.personality or "")[:30]
summary.append(f"- {char.name} ({char_role}): {personality}")
return "\n".join(summary[:20]) # 最多显示20个
async def _analyze_organization_needs(
self,
project: Project,
outline_content: str,
existing_orgs_summary: str,
existing_chars_summary: str,
db: AsyncSession,
user_id: str,
enable_mcp: bool,
all_chapters_brief: str = "",
start_chapter: int = 1,
chapter_count: int = 3,
plot_stage: str = "发展",
story_direction: str = "继续推进主线剧情"
) -> Dict[str, Any]:
"""AI预测性分析是否需要新组织"""
# 构建分析提示词
template = await PromptService.get_template(
"AUTO_ORGANIZATION_ANALYSIS",
user_id,
db
)
# 使用新的预测性分析参数
prompt = PromptService.format_prompt(
template,
title=project.title,
theme=project.theme or "未设定",
genre=project.genre or "未设定",
time_period=project.world_time_period or "未设定",
location=project.world_location or "未设定",
atmosphere=project.world_atmosphere or "未设定",
existing_organizations=existing_orgs_summary,
existing_characters=existing_chars_summary,
all_chapters_brief=all_chapters_brief,
start_chapter=start_chapter,
chapter_count=chapter_count,
plot_stage=plot_stage,
story_direction=story_direction
)
try:
# 使用统一的JSON调用方法(支持自动MCP工具加载)
analysis = await self.ai_service.call_with_json_retry(
prompt=prompt,
max_retries=3,
)
logger.info(f" ✅ AI分析完成: needs_new_organizations={analysis.get('needs_new_organizations')}")
return analysis
except json.JSONDecodeError as e:
logger.error(f" ❌ 组织需求分析JSON解析失败: {e}")
return {"needs_new_organizations": False}
except Exception as e:
logger.error(f" ❌ 组织需求分析失败: {e}")
return {"needs_new_organizations": False}
async def _generate_organization_details(
self,
spec: Dict[str, Any],
project: Project,
existing_characters: List[Character],
existing_organizations: List[Dict[str, Any]],
db: AsyncSession,
user_id: str,
enable_mcp: bool
) -> Dict[str, Any]:
"""生成组织详细信息"""
# 构建组织生成提示词
template = await PromptService.get_template(
"AUTO_ORGANIZATION_GENERATION",
user_id,
db
)
existing_orgs_summary = self._build_organization_summary(existing_organizations)
existing_chars_summary = self._build_character_summary(existing_characters)
prompt = PromptService.format_prompt(
template,
title=project.title,
genre=project.genre or "未设定",
theme=project.theme or "未设定",
time_period=project.world_time_period or "未设定",
location=project.world_location or "未设定",
atmosphere=project.world_atmosphere or "未设定",
rules=project.world_rules or "未设定",
existing_organizations=existing_orgs_summary,
existing_characters=existing_chars_summary,
plot_context="根据剧情需要引入的新组织",
organization_specification=json.dumps(spec, ensure_ascii=False, indent=2),
mcp_references="" # 暂时不使用MCP增强
)
# 调用AI生成(使用统一的JSON调用方法)
try:
# 使用统一的JSON调用方法(支持自动MCP工具加载)
organization_data = await self.ai_service.call_with_json_retry(
prompt=prompt,
max_retries=3,
)
org_name = organization_data.get('name', '未知')
logger.info(f" ✅ 组织详情生成成功: {org_name}")
logger.debug(f" 组织数据字段: {list(organization_data.keys())}")
# 确保关键字段存在
if 'name' not in organization_data or not organization_data['name']:
logger.warning(f" ⚠️ AI返回的组织数据缺少name字段,使用规格中的信息")
organization_data['name'] = spec.get('name', f"新组织{spec.get('organization_description', '')[:10]}")
return organization_data
except Exception as e:
logger.error(f" ❌ 生成组织详情失败: {e}")
raise
async def _create_organization_record(
self,
project_id: str,
organization_data: Dict[str, Any],
db: AsyncSession
) -> tuple:
"""创建组织数据库记录(包括Character和Organization"""
# 首先创建Character记录(is_organization=True
character = Character(
project_id=project_id,
name=organization_data.get("name", "未命名组织"),
is_organization=True,
role_type=organization_data.get("role_type", "supporting"),
personality=organization_data.get("personality", ""), # 组织特性
background=organization_data.get("background", ""), # 组织背景
appearance=organization_data.get("appearance", ""), # 外在表现
organization_type=organization_data.get("organization_type"),
organization_purpose=organization_data.get("organization_purpose"),
traits=json.dumps(organization_data.get("traits", []), ensure_ascii=False) if organization_data.get("traits") else None
)
db.add(character)
await db.flush()
# 然后创建Organization记录
organization = Organization(
character_id=character.id,
project_id=project_id,
power_level=organization_data.get("power_level", 50),
member_count=0,
location=organization_data.get("location"),
motto=organization_data.get("motto"),
color=organization_data.get("color")
)
db.add(organization)
await db.flush()
logger.info(f" ✅ 创建组织记录: {character.name}, Organization ID: {organization.id}")
return character, organization
async def _create_member_relationships(
self,
organization: Organization,
member_specs: List[Dict[str, Any]],
existing_characters: List[Character],
project_id: str,
db: AsyncSession
) -> List[OrganizationMember]:
"""创建组织成员关系"""
if not member_specs:
return []
members = []
for member_spec in member_specs:
try:
character_name = member_spec.get("character_name")
if not character_name:
continue
# 查找目标角色
target_char = next(
(c for c in existing_characters if c.name == character_name and not c.is_organization),
None
)
if not target_char:
logger.warning(f" ⚠️ 目标角色不存在: {character_name}")
continue
# 检查成员关系是否已存在
existing_member = await db.execute(
select(OrganizationMember).where(
OrganizationMember.organization_id == organization.id,
OrganizationMember.character_id == target_char.id
)
)
if existing_member.scalar_one_or_none():
logger.debug(f" ️ 成员关系已存在: {character_name} -> {organization.id}")
continue
# 创建成员关系
member = OrganizationMember(
organization_id=organization.id,
character_id=target_char.id,
position=member_spec.get("position", "成员"),
rank=member_spec.get("rank", 0),
loyalty=member_spec.get("loyalty", 50),
status=member_spec.get("status", "active"),
joined_at=member_spec.get("joined_at"),
source="auto" # 标记为自动生成
)
db.add(member)
members.append(member)
logger.info(
f" ✅ 创建成员关系: {character_name} -> {organization.id} "
f"({member_spec.get('position', '成员')})"
)
except Exception as e:
logger.warning(f" ❌ 创建成员关系失败: {e}")
continue
# 更新组织成员数量
if members:
organization.member_count = (organization.member_count or 0) + len(members)
return members
# 全局实例缓存
_auto_organization_service_instance: Optional[AutoOrganizationService] = None
def get_auto_organization_service(ai_service: AIService) -> AutoOrganizationService:
"""获取自动组织服务实例(单例模式)"""
global _auto_organization_service_instance
if _auto_organization_service_instance is None:
_auto_organization_service_instance = AutoOrganizationService(ai_service)
return _auto_organization_service_instance
+234
View File
@@ -0,0 +1,234 @@
"""职业生成服务"""
from typing import Dict, Any, List
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
import json
from app.models.project import Project
from app.models.career import Career
from app.services.ai_service import AIService
from app.logger import get_logger
logger = get_logger(__name__)
class CareerService:
"""职业相关业务逻辑服务"""
@staticmethod
async def get_career_generation_prompt(
project: Project,
main_career_count: int = 2,
sub_career_count: int = 6
) -> str:
"""
构建职业体系生成的提示词
Args:
project: 项目对象
main_career_count: 主职业数量
sub_career_count: 副职业数量
Returns:
完整的提示词
"""
project_context = f"""
项目信息
- 书名{project.title}
- 类型{project.genre or '未设定'}
- 主题{project.theme or '未设定'}
- 时间背景{project.world_time_period or '未设定'}
- 地理位置{project.world_location or '未设定'}
- 氛围基调{project.world_atmosphere or '未设定'}
- 世界规则{project.world_rules or '未设定'}
"""
user_requirements = f"""
生成要求
- 主职业数量{main_career_count}
- 副职业数量{sub_career_count}
- 主职业必须严格符合世界观规则体现核心能力体系
- 副职业可以更加自由灵活包含生产辅助特殊类型
"""
prompt = f"""{project_context}
{user_requirements}
请为这个小说项目生成完整的职业体系返回JSON格式结构如下
{{
"main_careers": [
{{
"name": "职业名称",
"description": "职业描述(100-200字)",
"category": "职业分类(如:战斗系、法术系、体修系等)",
"stages": [
{{"level": 1, "name": "阶段名称", "description": "阶段描述"}},
{{"level": 2, "name": "阶段名称", "description": "阶段描述"}},
...共10个阶段
],
"max_stage": 10,
"requirements": "职业要求(如:需要特定天赋、资质等)",
"special_abilities": "特殊能力描述",
"worldview_rules": "世界观规则关联(说明该职业如何融入世界观)",
"attribute_bonuses": {{"strength": "+10%", "intelligence": "+5%"}}
}}
],
"sub_careers": [
{{
"name": "副职业名称",
"description": "职业描述",
"category": "生产系/辅助系/特殊系",
"stages": [
{{"level": 1, "name": "阶段名称", "description": "阶段描述"}},
...5-8个阶段
],
"max_stage": 5,
"requirements": "职业要求",
"special_abilities": "特殊能力"
}}
]
}}
重要注意事项
1. 主职业的阶段设定要详细体现明确的成长路径阶段名称要有特色
2. 根据小说类型选择合适的职业
- 修仙类剑修体修法修符修等阶段如炼气筑基金丹元婴...
- 玄幻类战士法师刺客等阶段如见习初级中级高级...
- 都市异能异能者分类阶段如觉醒初阶中阶高阶...
- 科幻未来基因战士机甲师等阶段如E级D级C级B级...
3. 副职业要有实用性和趣味性炼丹师炼器师阵法师驯兽师医师等
4. 所有职业都要符合项目的整体世界观设定
5. 阶段描述要简洁明了体现该阶段的核心特征
6. **只返回纯JSON对象不要添加任何解释文字或markdown标记**
"""
return prompt
@staticmethod
async def parse_and_save_careers(
career_data: Dict[str, Any],
project_id: str,
db: AsyncSession
) -> Dict[str, List[str]]:
"""
解析AI返回的职业数据并保存到数据库
Args:
career_data: AI返回的职业数据已解析为dict
project_id: 项目ID
db: 数据库会话
Returns:
{"main_careers": [...], "sub_careers": [...]} 创建的职业名称列表
"""
result = {
"main_careers": [],
"sub_careers": []
}
# 保存主职业
for idx, career_info in enumerate(career_data.get("main_careers", [])):
try:
stages_json = json.dumps(career_info.get("stages", []), ensure_ascii=False)
attribute_bonuses = career_info.get("attribute_bonuses")
attribute_bonuses_json = json.dumps(attribute_bonuses, ensure_ascii=False) if attribute_bonuses else None
career = Career(
project_id=project_id,
name=career_info.get("name", f"未命名主职业{idx+1}"),
type="main",
description=career_info.get("description"),
category=career_info.get("category"),
stages=stages_json,
max_stage=career_info.get("max_stage", 10),
requirements=career_info.get("requirements"),
special_abilities=career_info.get("special_abilities"),
worldview_rules=career_info.get("worldview_rules"),
attribute_bonuses=attribute_bonuses_json,
source="ai"
)
db.add(career)
await db.flush()
result["main_careers"].append(career.name)
logger.info(f" ✅ 创建主职业:{career.name}")
except Exception as e:
logger.error(f" ❌ 创建主职业失败:{str(e)}")
continue
# 保存副职业
for idx, career_info in enumerate(career_data.get("sub_careers", [])):
try:
stages_json = json.dumps(career_info.get("stages", []), ensure_ascii=False)
attribute_bonuses = career_info.get("attribute_bonuses")
attribute_bonuses_json = json.dumps(attribute_bonuses, ensure_ascii=False) if attribute_bonuses else None
career = Career(
project_id=project_id,
name=career_info.get("name", f"未命名副职业{idx+1}"),
type="sub",
description=career_info.get("description"),
category=career_info.get("category"),
stages=stages_json,
max_stage=career_info.get("max_stage", 5),
requirements=career_info.get("requirements"),
special_abilities=career_info.get("special_abilities"),
worldview_rules=career_info.get("worldview_rules"),
attribute_bonuses=attribute_bonuses_json,
source="ai"
)
db.add(career)
await db.flush()
result["sub_careers"].append(career.name)
logger.info(f" ✅ 创建副职业:{career.name}")
except Exception as e:
logger.error(f" ❌ 创建副职业失败:{str(e)}")
continue
await db.commit()
return result
@staticmethod
async def get_project_careers_summary(project_id: str, db: AsyncSession) -> Dict[str, Any]:
"""
获取项目职业体系摘要
Args:
project_id: 项目ID
db: 数据库会话
Returns:
职业体系摘要信息
"""
result = await db.execute(
select(Career).where(Career.project_id == project_id)
)
careers = result.scalars().all()
main_careers = []
sub_careers = []
for career in careers:
career_info = {
"id": career.id,
"name": career.name,
"category": career.category,
"max_stage": career.max_stage
}
if career.type == "main":
main_careers.append(career_info)
else:
sub_careers.append(career_info)
return {
"main_careers": main_careers,
"sub_careers": sub_careers,
"total_count": len(careers)
}
# 创建全局服务实例
career_service = CareerService()
@@ -0,0 +1,398 @@
"""职业更新服务 - 根据章节分析自动更新角色职业信息"""
from typing import Dict, Any, List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.models.character import Character
from app.models.career import Career, CharacterCareer
from app.logger import get_logger
logger = get_logger(__name__)
class CareerUpdateService:
"""职业更新服务 - 根据章节分析结果自动更新角色职业"""
@staticmethod
async def update_careers_from_analysis(
db: AsyncSession,
project_id: str,
character_states: List[Dict[str, Any]],
chapter_id: str,
chapter_number: int
) -> Dict[str, Any]:
"""
根据章节分析结果更新角色职业
Args:
db: 数据库会话
project_id: 项目ID
character_states: 角色状态变化列表来自PlotAnalysis
chapter_id: 章节ID
chapter_number: 章节编号
Returns:
更新结果字典包含更新数量和变更日志
"""
if not character_states:
logger.info("📋 角色状态列表为空,跳过职业更新")
return {"updated_count": 0, "changes": []}
updated_count = 0
changes_log = []
logger.info(f"🔍 开始分析第{chapter_number}章的角色职业变化...")
for char_state in character_states:
char_name = char_state.get('character_name')
career_changes = char_state.get('career_changes', {})
# 如果没有职业变化信息,跳过
if not career_changes or not isinstance(career_changes, dict):
continue
# 检查是否有实质性的职业变化
main_stage_change = career_changes.get('main_career_stage_change', 0)
sub_career_changes = career_changes.get('sub_career_changes', [])
new_careers = career_changes.get('new_careers', [])
if main_stage_change == 0 and not sub_career_changes and not new_careers:
continue
logger.info(f" 👤 检测到角色 [{char_name}] 有职业变化")
# 1. 查询角色
char_result = await db.execute(
select(Character).where(
Character.name == char_name,
Character.project_id == project_id
)
)
character = char_result.scalar_one_or_none()
if not character:
logger.warning(f" ⚠️ 角色不存在: {char_name},跳过")
continue
# 2. 更新主职业阶段
if main_stage_change != 0 and character.main_career_id:
success = await CareerUpdateService._update_main_career_stage(
db=db,
character=character,
stage_change=main_stage_change,
chapter_number=chapter_number,
career_changes=career_changes,
changes_log=changes_log
)
if success:
updated_count += 1
# 3. 更新副职业(如果有)
if sub_career_changes and isinstance(sub_career_changes, list):
for sub_change in sub_career_changes:
success = await CareerUpdateService._update_sub_career_stage(
db=db,
character=character,
project_id=project_id,
sub_change=sub_change,
chapter_number=chapter_number,
changes_log=changes_log
)
if success:
updated_count += 1
# 4. 添加新职业(如果有)
if new_careers and isinstance(new_careers, list):
for new_career_name in new_careers:
success = await CareerUpdateService._add_new_career(
db=db,
character=character,
project_id=project_id,
career_name=new_career_name,
chapter_number=chapter_number,
changes_log=changes_log
)
if success:
updated_count += 1
# 提交所有更改
if updated_count > 0:
await db.commit()
logger.info(f"✅ 职业更新完成: 共更新了 {updated_count} 个角色的职业信息")
else:
logger.info("📋 本章没有角色职业变化")
return {
"updated_count": updated_count,
"changes": changes_log
}
@staticmethod
async def _update_main_career_stage(
db: AsyncSession,
character: Character,
stage_change: int,
chapter_number: int,
career_changes: Dict[str, Any],
changes_log: List[Dict[str, Any]]
) -> bool:
"""更新主职业阶段"""
try:
# 查询主职业关联
char_career_result = await db.execute(
select(CharacterCareer).where(
CharacterCareer.character_id == character.id,
CharacterCareer.career_type == 'main'
)
)
char_career = char_career_result.scalar_one_or_none()
if not char_career:
logger.warning(f" ⚠️ {character.name} 没有主职业关联记录")
return False
# 查询职业信息
career_result = await db.execute(
select(Career).where(Career.id == char_career.career_id)
)
career = career_result.scalar_one_or_none()
if not career:
logger.warning(f" ⚠️ 职业ID {char_career.career_id} 不存在")
return False
# 计算新阶段(不超过最大阶段,不低于1)
old_stage = char_career.current_stage
new_stage = min(max(1, old_stage + stage_change), career.max_stage)
# 如果没有实际变化,跳过
if new_stage == old_stage:
logger.info(f" 📊 {character.name}{career.name} 已达到边界,无法变更")
return False
# 更新CharacterCareer表
char_career.current_stage = new_stage
# 同步更新Character表的冗余字段
character.main_career_stage = new_stage
# 记录变更日志
change_desc = f"{'晋升' if stage_change > 0 else '降级'}"
breakthrough_desc = career_changes.get('career_breakthrough', '')
changes_log.append({
'character': character.name,
'career': career.name,
'career_type': 'main',
'old_stage': old_stage,
'new_stage': new_stage,
'change': stage_change,
'chapter': chapter_number,
'description': breakthrough_desc
})
logger.info(
f"{character.name} 的主职业 [{career.name}] "
f"{old_stage}阶 → {new_stage}阶 ({change_desc})"
)
if breakthrough_desc:
logger.info(f" 突破描述: {breakthrough_desc[:50]}...")
return True
except Exception as e:
logger.error(f" ❌ 更新主职业失败: {str(e)}")
return False
@staticmethod
async def _update_sub_career_stage(
db: AsyncSession,
character: Character,
project_id: str,
sub_change: Dict[str, Any],
chapter_number: int,
changes_log: List[Dict[str, Any]]
) -> bool:
"""更新副职业阶段"""
try:
career_name = sub_change.get('career_name')
stage_change = sub_change.get('stage_change', 0)
if not career_name or stage_change == 0:
return False
# 1. 查询职业(通过名称)
career_result = await db.execute(
select(Career).where(
Career.name == career_name,
Career.project_id == project_id,
Career.type == 'sub'
)
)
career = career_result.scalar_one_or_none()
if not career:
logger.warning(f" ⚠️ 副职业 [{career_name}] 不存在")
return False
# 2. 查询角色-职业关联
char_career_result = await db.execute(
select(CharacterCareer).where(
CharacterCareer.character_id == character.id,
CharacterCareer.career_id == career.id,
CharacterCareer.career_type == 'sub'
)
)
char_career = char_career_result.scalar_one_or_none()
if not char_career:
logger.warning(f" ⚠️ {character.name} 没有 [{career_name}] 副职业")
return False
# 3. 计算新阶段
old_stage = char_career.current_stage
new_stage = min(max(1, old_stage + stage_change), career.max_stage)
if new_stage == old_stage:
return False
# 4. 更新阶段
char_career.current_stage = new_stage
# 5. 同步更新Character表的sub_careers JSON字段
import json
sub_careers = json.loads(character.sub_careers) if character.sub_careers else []
for sc in sub_careers:
if sc.get('career_id') == career.id:
sc['stage'] = new_stage
break
character.sub_careers = json.dumps(sub_careers, ensure_ascii=False)
# 6. 记录变更
changes_log.append({
'character': character.name,
'career': career.name,
'career_type': 'sub',
'old_stage': old_stage,
'new_stage': new_stage,
'change': stage_change,
'chapter': chapter_number
})
logger.info(
f"{character.name} 的副职业 [{career.name}] "
f"{old_stage}阶 → {new_stage}"
)
return True
except Exception as e:
logger.error(f" ❌ 更新副职业失败: {str(e)}")
return False
@staticmethod
async def _add_new_career(
db: AsyncSession,
character: Character,
project_id: str,
career_name: str,
chapter_number: int,
changes_log: List[Dict[str, Any]]
) -> bool:
"""为角色添加新职业"""
try:
# 1. 查询职业
career_result = await db.execute(
select(Career).where(
Career.name == career_name,
Career.project_id == project_id
)
)
career = career_result.scalar_one_or_none()
if not career:
logger.warning(f" ⚠️ 职业 [{career_name}] 不存在,无法添加")
return False
# 2. 检查是否已存在
existing_result = await db.execute(
select(CharacterCareer).where(
CharacterCareer.character_id == character.id,
CharacterCareer.career_id == career.id
)
)
if existing_result.scalar_one_or_none():
logger.info(f" 📋 {character.name} 已拥有 [{career_name}],跳过")
return False
# 3. 根据职业类型添加
if career.type == 'main':
# 检查是否已有主职业
if character.main_career_id:
logger.warning(f" ⚠️ {character.name} 已有主职业,无法添加新主职业")
return False
# 添加主职业
import uuid
new_char_career = CharacterCareer(
id=str(uuid.uuid4()),
character_id=character.id,
career_id=career.id,
career_type='main',
current_stage=1
)
db.add(new_char_career)
# 更新Character表
character.main_career_id = career.id
character.main_career_stage = 1
logger.info(f"{character.name} 获得新主职业 [{career_name}]")
else: # sub职业
# 检查副职业数量(最多2个)
sub_count_result = await db.execute(
select(CharacterCareer).where(
CharacterCareer.character_id == character.id,
CharacterCareer.career_type == 'sub'
)
)
if len(sub_count_result.scalars().all()) >= 2:
logger.warning(f" ⚠️ {character.name} 的副职业已达上限(2个)")
return False
# 添加副职业
import uuid
new_char_career = CharacterCareer(
id=str(uuid.uuid4()),
character_id=character.id,
career_id=career.id,
career_type='sub',
current_stage=1
)
db.add(new_char_career)
# 更新Character表的sub_careers JSON
import json
sub_careers = json.loads(character.sub_careers) if character.sub_careers else []
sub_careers.append({
'career_id': career.id,
'stage': 1
})
character.sub_careers = json.dumps(sub_careers, ensure_ascii=False)
logger.info(f"{character.name} 获得新副职业 [{career_name}]")
# 记录变更
changes_log.append({
'character': character.name,
'career': career.name,
'career_type': career.type,
'action': 'new',
'chapter': chapter_number
})
return True
except Exception as e:
logger.error(f" ❌ 添加新职业失败: {str(e)}")
return False
@@ -0,0 +1,745 @@
"""章节上下文构建服务 - 实现RTCO框架的智能上下文构建"""
from dataclasses import dataclass, field
from typing import Dict, Any, Optional, List
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
import json
from app.models.chapter import Chapter
from app.models.project import Project
from app.models.outline import Outline
from app.models.character import Character
from app.models.career import Career, CharacterCareer
from app.models.memory import StoryMemory
from app.logger import get_logger
logger = get_logger(__name__)
@dataclass
class ChapterContext:
"""
章节上下文数据结构
采用RTCO框架的分层设计
- P0-核心必须大纲衔接点字数要求
- P1-重要按需角色情感基调风格
- P2-参考条件触发记忆故事骨架MCP资料
"""
# === P0-核心信息(必须包含)===
chapter_outline: str = "" # 本章大纲
continuation_point: Optional[str] = None # 衔接锚点(上一章结尾)
target_word_count: int = 3000 # 目标字数
min_word_count: int = 2500 # 最小字数
max_word_count: int = 4000 # 最大字数
narrative_perspective: str = "第三人称" # 叙事视角
# === 本章基本信息 ===
chapter_number: int = 1 # 章节序号
chapter_title: str = "" # 章节标题
# === 项目基本信息 ===
title: str = "" # 书名
genre: str = "" # 类型
theme: str = "" # 主题
# === P1-重要信息(按需包含)===
chapter_characters: str = "" # 本章涉及角色(精简)
emotional_tone: str = "" # 情感基调
style_instruction: str = "" # 写作风格指令(摘要化)
# === P2-参考信息(条件触发)===
relevant_memories: Optional[str] = None # 相关记忆(精简版)
story_skeleton: Optional[str] = None # 故事骨架(50章+启用)
mcp_references: Optional[str] = None # MCP参考资料
# === 元信息 ===
context_stats: Dict[str, Any] = field(default_factory=dict) # 统计信息
def get_total_context_length(self) -> int:
"""计算总上下文长度"""
total = 0
for field_name in ['chapter_outline', 'continuation_point', 'chapter_characters',
'relevant_memories', 'story_skeleton', 'style_instruction']:
value = getattr(self, field_name, None)
if value:
total += len(value)
return total
class ChapterContextBuilder:
"""
章节上下文构建器
实现动态裁剪逻辑根据章节序号自动调整上下文复杂度
- 第1章无前置上下文仅提供大纲和角色
- 第2-10上一章结尾300字 + 涉及角色
- 第11-50上一章结尾500字 + 相关记忆3条
- 第51章+上一章结尾500字 + 故事骨架 + 智能记忆5条
"""
# 配置常量
ENDING_LENGTH_SHORT = 300 # 1-10章:短衔接
ENDING_LENGTH_NORMAL = 500 # 11章+:标准衔接
MEMORY_COUNT_LIGHT = 3 # 11-50章:轻量记忆
MEMORY_COUNT_FULL = 5 # 51章+:完整记忆
SKELETON_THRESHOLD = 50 # 启用故事骨架的章节阈值
SKELETON_SAMPLE_INTERVAL = 10 # 故事骨架采样间隔
MEMORY_IMPORTANCE_THRESHOLD = 0.7 # 记忆重要性阈值
STYLE_MAX_LENGTH = 200 # 风格描述最大长度
MAX_CONTEXT_LENGTH = 3000 # 总上下文最大字符数
def __init__(self, memory_service=None):
"""
初始化构建器
Args:
memory_service: 记忆服务实例可选用于检索相关记忆
"""
self.memory_service = memory_service
async def build(
self,
chapter: Chapter,
project: Project,
outline: Optional[Outline],
user_id: str,
db: AsyncSession,
style_content: Optional[str] = None,
target_word_count: int = 3000,
temp_narrative_perspective: Optional[str] = None
) -> ChapterContext:
"""
构建章节生成所需的上下文
Args:
chapter: 章节对象
project: 项目对象
outline: 大纲对象可选
user_id: 用户ID
db: 数据库会话
style_content: 写作风格内容可选
target_word_count: 目标字数
temp_narrative_perspective: 临时叙事视角可选覆盖项目默认
Returns:
ChapterContext: 结构化的上下文对象
"""
chapter_number = chapter.chapter_number
logger.info(f"📝 开始构建章节上下文: 第{chapter_number}")
# 确定叙事视角
narrative_perspective = (
temp_narrative_perspective or
project.narrative_perspective or
"第三人称"
)
# 初始化上下文
context = ChapterContext(
chapter_number=chapter_number,
chapter_title=chapter.title or "",
title=project.title or "",
genre=project.genre or "",
theme=project.theme or "",
target_word_count=target_word_count,
min_word_count=max(500, target_word_count - 500),
max_word_count=target_word_count + 1000,
narrative_perspective=narrative_perspective
)
# === P0-核心信息(始终构建)===
context.chapter_outline = await self._build_chapter_outline(
chapter, outline, project.outline_mode
)
# === 衔接锚点(根据章节调整长度)===
if chapter_number == 1:
context.continuation_point = None
logger.info(" ✅ 第1章无需衔接锚点")
elif chapter_number <= 10:
context.continuation_point = await self._get_last_ending(
chapter, db, self.ENDING_LENGTH_SHORT
)
logger.info(f" ✅ 衔接锚点(短): {len(context.continuation_point or '')}字符")
else:
context.continuation_point = await self._get_last_ending(
chapter, db, self.ENDING_LENGTH_NORMAL
)
logger.info(f" ✅ 衔接锚点(标准): {len(context.continuation_point or '')}字符")
# === P1-重要信息 ===
context.chapter_characters = await self._build_chapter_characters(
chapter, project, outline, db
)
context.emotional_tone = self._extract_emotional_tone(chapter, outline)
# 写作风格(摘要化)
if style_content:
context.style_instruction = self._summarize_style(style_content)
# === P2-参考信息(条件触发)===
if chapter_number > 10 and self.memory_service:
memory_limit = (
self.MEMORY_COUNT_LIGHT if chapter_number <= 50
else self.MEMORY_COUNT_FULL
)
context.relevant_memories = await self._get_relevant_memories(
user_id, project.id, chapter_number,
context.chapter_outline,
limit=memory_limit
)
logger.info(f" ✅ 相关记忆: {len(context.relevant_memories or '')}字符")
# 故事骨架(50章+
if chapter_number > self.SKELETON_THRESHOLD:
context.story_skeleton = await self._build_story_skeleton(
project.id, chapter_number, db
)
logger.info(f" ✅ 故事骨架: {len(context.story_skeleton or '')}字符")
# === 统计信息 ===
context.context_stats = {
"chapter_number": chapter_number,
"has_continuation": context.continuation_point is not None,
"continuation_length": len(context.continuation_point or ""),
"characters_length": len(context.chapter_characters),
"memories_length": len(context.relevant_memories or ""),
"skeleton_length": len(context.story_skeleton or ""),
"total_length": context.get_total_context_length()
}
logger.info(f"📊 上下文构建完成: 总长度 {context.context_stats['total_length']} 字符")
return context
async def _build_chapter_outline(
self,
chapter: Chapter,
outline: Optional[Outline],
outline_mode: str
) -> str:
"""
构建本章大纲内容
Args:
chapter: 章节对象
outline: 大纲对象
outline_mode: 大纲模式one-to-one/one-to-many
Returns:
本章大纲文本
"""
if outline_mode == 'one-to-one':
# 一对一模式:使用大纲的 content
return outline.content if outline else chapter.summary or '暂无大纲'
else:
# 一对多模式:优先使用 expansion_plan 的详细规划
if chapter.expansion_plan:
try:
plan = json.loads(chapter.expansion_plan)
outline_content = f"""剧情摘要:{plan.get('plot_summary', '')}
关键事件
{chr(10).join(f'- {event}' for event in plan.get('key_events', []))}
角色焦点{', '.join(plan.get('character_focus', []))}
情感基调{plan.get('emotional_tone', '未设定')}
叙事目标{plan.get('narrative_goal', '未设定')}
冲突类型{plan.get('conflict_type', '未设定')}"""
return outline_content
except json.JSONDecodeError:
pass
# 回退到大纲内容
return outline.content if outline else chapter.summary or '暂无大纲'
async def _get_last_ending(
self,
chapter: Chapter,
db: AsyncSession,
max_length: int
) -> Optional[str]:
"""
获取上一章结尾内容作为衔接锚点
Args:
chapter: 当前章节
db: 数据库会话
max_length: 最大长度
Returns:
上一章结尾内容
"""
if chapter.chapter_number <= 1:
return None
# 查询上一章
result = await db.execute(
select(Chapter)
.where(Chapter.project_id == chapter.project_id)
.where(Chapter.chapter_number == chapter.chapter_number - 1)
)
prev_chapter = result.scalar_one_or_none()
if not prev_chapter or not prev_chapter.content:
return None
# 提取结尾内容
content = prev_chapter.content.strip()
if len(content) <= max_length:
return content
return content[-max_length:]
async def _build_chapter_characters(
self,
chapter: Chapter,
project: Project,
outline: Optional[Outline],
db: AsyncSession
) -> str:
"""
构建本章涉及的角色信息精简版
只返回本章相关的角色而非全部角色
Args:
chapter: 章节对象
project: 项目对象
outline: 大纲对象
db: 数据库会话
Returns:
本章角色信息文本
"""
# 获取所有角色
characters_result = await db.execute(
select(Character).where(Character.project_id == project.id)
)
characters = characters_result.scalars().all()
if not characters:
return "暂无角色信息"
# 提取本章相关角色名单
filter_character_names = None
# 从大纲或扩展计划中提取角色
if project.outline_mode == 'one-to-one':
if outline and outline.structure:
try:
structure = json.loads(outline.structure)
filter_character_names = structure.get('characters', [])
except json.JSONDecodeError:
pass
else:
if chapter.expansion_plan:
try:
plan = json.loads(chapter.expansion_plan)
filter_character_names = plan.get('character_focus', [])
except json.JSONDecodeError:
pass
# 筛选角色
if filter_character_names:
characters = [c for c in characters if c.name in filter_character_names]
if not characters:
return "暂无相关角色"
# 构建精简的角色信息(每个角色最多100字符)
char_lines = []
for c in characters[:10]: # 最多10个角色
role_type = "主角" if c.role_type == "protagonist" else (
"反派" if c.role_type == "antagonist" else "配角"
)
# 性格摘要(最多50字符)
personality_brief = ""
if c.personality:
personality_brief = c.personality[:50]
if len(c.personality) > 50:
personality_brief += "..."
char_lines.append(f"- {c.name}({role_type}): {personality_brief}")
return "\n".join(char_lines)
def _extract_emotional_tone(
self,
chapter: Chapter,
outline: Optional[Outline]
) -> str:
"""
提取本章情感基调
Args:
chapter: 章节对象
outline: 大纲对象
Returns:
情感基调描述
"""
# 尝试从扩展计划中提取
if chapter.expansion_plan:
try:
plan = json.loads(chapter.expansion_plan)
tone = plan.get('emotional_tone')
if tone:
return tone
except json.JSONDecodeError:
pass
# 尝试从大纲结构中提取
if outline and outline.structure:
try:
structure = json.loads(outline.structure)
tone = structure.get('emotion') or structure.get('emotional_tone')
if tone:
return tone
except json.JSONDecodeError:
pass
return "未设定"
def _summarize_style(self, style_content: str) -> str:
"""
将风格描述压缩为关键要点
Args:
style_content: 完整风格描述
Returns:
摘要化的风格描述
"""
if not style_content:
return ""
if len(style_content) <= self.STYLE_MAX_LENGTH:
return style_content
# 简单截断(后续可以用AI提取关键词)
return style_content[:self.STYLE_MAX_LENGTH] + "..."
async def _get_relevant_memories(
self,
user_id: str,
project_id: str,
chapter_number: int,
chapter_outline: str,
limit: int = 3
) -> Optional[str]:
"""
获取与本章最相关的记忆精简版
策略
1. 仅检索与大纲语义最相关的记忆
2. 提高重要性阈值过滤低质量记忆
3. 优先返回未回收的伏笔
Args:
user_id: 用户ID
project_id: 项目ID
chapter_number: 当前章节号
chapter_outline: 本章大纲
limit: 返回数量限制
Returns:
格式化的记忆文本
"""
if not self.memory_service:
return None
try:
# 1. 语义检索相关记忆(提高阈值)
relevant = await self.memory_service.search_memories(
user_id=user_id,
project_id=project_id,
query=chapter_outline,
limit=limit,
min_importance=self.MEMORY_IMPORTANCE_THRESHOLD
)
# 2. 检查即将到期的伏笔
foreshadows = await self._get_due_foreshadows(
user_id, project_id, chapter_number,
lookahead=5 # 仅看5章内需要回收的
)
# 3. 合并并格式化
return self._format_memories(relevant, foreshadows, max_length=500)
except Exception as e:
logger.error(f"❌ 获取相关记忆失败: {str(e)}")
return None
async def _get_due_foreshadows(
self,
user_id: str,
project_id: str,
chapter_number: int,
lookahead: int = 5
) -> List[Dict[str, Any]]:
"""
获取即将需要回收的伏笔
Args:
user_id: 用户ID
project_id: 项目ID
chapter_number: 当前章节号
lookahead: 往前看的章节数
Returns:
待回收伏笔列表
"""
if not self.memory_service:
return []
try:
foreshadows = await self.memory_service.find_unresolved_foreshadows(
user_id, project_id, chapter_number
)
# 过滤:只保留埋下时间较长(超过lookahead章)的伏笔
due_foreshadows = []
for fs in foreshadows:
meta = fs.get('metadata', {})
fs_chapter = meta.get('chapter_number', 0)
if chapter_number - fs_chapter >= lookahead:
due_foreshadows.append({
'chapter': fs_chapter,
'content': fs.get('content', '')[:60],
'importance': meta.get('importance', 0.5)
})
return due_foreshadows[:2] # 最多2条
except Exception as e:
logger.error(f"❌ 获取待回收伏笔失败: {str(e)}")
return []
def _format_memories(
self,
relevant: List[Dict[str, Any]],
foreshadows: List[Dict[str, Any]],
max_length: int = 500
) -> str:
"""
格式化记忆为简洁文本严格限制长度
Args:
relevant: 相关记忆列表
foreshadows: 待回收伏笔列表
max_length: 最大长度
Returns:
格式化的记忆文本
"""
lines = []
current_length = 0
# 优先添加待回收伏笔
if foreshadows:
lines.append("【待回收伏笔】")
for fs in foreshadows[:2]:
text = f"- 第{fs['chapter']}章埋下:{fs['content']}"
if current_length + len(text) > max_length:
break
lines.append(text)
current_length += len(text)
# 添加相关记忆
if relevant and current_length < max_length:
lines.append("【相关记忆】")
for mem in relevant:
content = mem.get('content', '')[:80]
text = f"- {content}"
if current_length + len(text) > max_length:
break
lines.append(text)
current_length += len(text)
return "\n".join(lines) if lines else None
async def _build_story_skeleton(
self,
project_id: str,
chapter_number: int,
db: AsyncSession
) -> Optional[str]:
"""
构建故事骨架每N章采样
Args:
project_id: 项目ID
chapter_number: 当前章节号
db: 数据库会话
Returns:
故事骨架文本
"""
try:
# 获取所有已完成章节的摘要
result = await db.execute(
select(Chapter.chapter_number, Chapter.title)
.where(Chapter.project_id == project_id)
.where(Chapter.chapter_number < chapter_number)
.where(Chapter.content != None)
.where(Chapter.content != "")
.order_by(Chapter.chapter_number)
)
chapters = result.all()
if not chapters:
return None
# 采样:每N章取一个
skeleton_lines = ["【故事骨架】"]
for i, (ch_num, ch_title) in enumerate(chapters):
if i % self.SKELETON_SAMPLE_INTERVAL == 0:
# 尝试获取章节摘要
summary_result = await db.execute(
select(StoryMemory.content)
.where(StoryMemory.project_id == project_id)
.where(StoryMemory.story_timeline == ch_num)
.where(StoryMemory.memory_type == 'chapter_summary')
.limit(1)
)
summary = summary_result.scalar_one_or_none()
if summary:
skeleton_lines.append(f"{ch_num}章《{ch_title}》:{summary[:100]}")
else:
skeleton_lines.append(f"{ch_num}章《{ch_title}")
if len(skeleton_lines) <= 1:
return None
return "\n".join(skeleton_lines)
except Exception as e:
logger.error(f"❌ 构建故事骨架失败: {str(e)}")
return None
class FocusedMemoryRetriever:
"""
精简记忆检索器
相比原有的memory_service提供更精准更简洁的记忆检索
"""
def __init__(self, memory_service):
"""
初始化检索器
Args:
memory_service: 基础记忆服务实例
"""
self.memory_service = memory_service
async def get_relevant_memories(
self,
user_id: str,
project_id: str,
chapter_number: int,
chapter_outline: str,
limit: int = 3
) -> str:
"""
获取与本章最相关的记忆
策略
1. 仅检索与大纲语义最相关的记忆
2. 提高重要性阈值过滤低质量记忆
3. 优先返回未回收的伏笔
Args:
user_id: 用户ID
project_id: 项目ID
chapter_number: 当前章节号
chapter_outline: 本章大纲
limit: 返回数量限制
Returns:
格式化的记忆文本
"""
# 1. 语义检索相关记忆(提高阈值)
relevant = await self.memory_service.search_memories(
user_id=user_id,
project_id=project_id,
query=chapter_outline,
limit=limit,
min_importance=0.7 # 从0.4提高到0.7
)
# 2. 检查即将到期的伏笔
due_foreshadows = await self._get_due_foreshadows(
user_id, project_id, chapter_number,
lookahead=5 # 仅看5章内需要回收的
)
# 3. 合并并格式化
return self._format_memories(relevant, due_foreshadows, max_length=500)
async def _get_due_foreshadows(
self,
user_id: str,
project_id: str,
chapter_number: int,
lookahead: int = 5
) -> List[Dict[str, Any]]:
"""获取即将需要回收的伏笔"""
foreshadows = await self.memory_service.find_unresolved_foreshadows(
user_id, project_id, chapter_number
)
# 过滤:只保留埋下时间较长的伏笔
due_foreshadows = []
for fs in foreshadows:
meta = fs.get('metadata', {})
fs_chapter = meta.get('chapter_number', 0)
if chapter_number - fs_chapter >= lookahead:
due_foreshadows.append({
'chapter': fs_chapter,
'content': fs.get('content', '')[:60],
'importance': meta.get('importance', 0.5)
})
return due_foreshadows[:2] # 最多2条
def _format_memories(
self,
relevant: List[Dict[str, Any]],
foreshadows: List[Dict[str, Any]],
max_length: int = 500
) -> str:
"""格式化为简洁文本,严格限制长度"""
lines = []
current_length = 0
# 优先添加待回收伏笔
if foreshadows:
lines.append("【待回收伏笔】")
for fs in foreshadows[:2]:
text = f"- 第{fs['chapter']}章埋下:{fs['content']}"
if current_length + len(text) > max_length:
break
lines.append(text)
current_length += len(text)
# 添加相关记忆
if relevant and current_length < max_length:
lines.append("【相关记忆】")
for mem in relevant:
content = mem.get('content', '')[:80]
text = f"- {content}"
if current_length + len(text) > max_length:
break
lines.append(text)
current_length += len(text)
return "\n".join(lines) if lines else ""
+248
View File
@@ -0,0 +1,248 @@
"""章节重新生成服务"""
from typing import Dict, Any, AsyncGenerator, Optional, List
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.ai_service import AIService
from app.services.prompt_service import prompt_service, PromptService
from app.models.chapter import Chapter
from app.models.memory import PlotAnalysis
from app.schemas.regeneration import ChapterRegenerateRequest, PreserveElementsConfig
from app.logger import get_logger
import difflib
logger = get_logger(__name__)
class ChapterRegenerator:
"""章节重新生成服务"""
def __init__(self, ai_service: AIService):
self.ai_service = ai_service
logger.info("✅ ChapterRegenerator初始化成功")
async def regenerate_with_feedback(
self,
chapter: Chapter,
analysis: Optional[PlotAnalysis],
regenerate_request: ChapterRegenerateRequest,
project_context: Dict[str, Any],
style_content: str = "",
user_id: str = None,
db: AsyncSession = None
) -> AsyncGenerator[Dict[str, Any], None]:
"""
根据反馈重新生成章节流式
Args:
chapter: 原始章节对象
analysis: 分析结果可选
regenerate_request: 重新生成请求参数
project_context: 项目上下文项目信息角色大纲等
style_content: 写作风格
user_id: 用户ID用于获取自定义提示词
db: 数据库会话用于查询自定义提示词
Yields:
包含类型和数据的字典: {'type': 'progress'/'chunk', 'data': ...}
"""
try:
logger.info(f"🔄 开始重新生成章节: 第{chapter.chapter_number}")
# 1. 构建修改指令
yield {'type': 'progress', 'progress': 5, 'message': '正在构建修改指令...'}
modification_instructions = self._build_modification_instructions(
analysis=analysis,
regenerate_request=regenerate_request
)
logger.info(f"📝 修改指令构建完成,长度: {len(modification_instructions)}字符")
# 2. 构建完整提示词
yield {'type': 'progress', 'progress': 10, 'message': '正在构建生成提示词...'}
full_prompt = await self._build_regeneration_prompt(
chapter=chapter,
modification_instructions=modification_instructions,
project_context=project_context,
regenerate_request=regenerate_request,
style_content=style_content,
user_id=user_id,
db=db
)
logger.info(f"🎯 提示词构建完成,开始AI生成")
yield {'type': 'progress', 'progress': 15, 'message': '开始AI生成内容...'}
# 3. 构建系统提示词(注入写作风格)
system_prompt_with_style = None
if style_content:
system_prompt_with_style = f"""【🎨 写作风格要求 - 最高优先级】
{style_content}
请严格遵循上述写作风格要求进行重写这是最重要的指令
确保在整个章节重写过程中始终保持风格的一致性"""
logger.info(f"✅ 已将写作风格注入系统提示词({len(style_content)}字符)")
# 4. 流式生成新内容,同时跟踪进度
target_word_count = regenerate_request.target_word_count
accumulated_length = 0
async for chunk in self.ai_service.generate_text_stream(
prompt=full_prompt,
system_prompt=system_prompt_with_style,
temperature=0.7
):
# 发送内容块
yield {'type': 'chunk', 'content': chunk}
# 更新累积字数并计算进度(15%-95%)
accumulated_length += len(chunk)
# 进度从15%开始,到95%结束,为后处理预留5%
generation_progress = min(15 + (accumulated_length / target_word_count) * 80, 95)
yield {'type': 'progress', 'progress': int(generation_progress), 'word_count': accumulated_length}
logger.info(f"✅ 章节重新生成完成,共生成 {accumulated_length}")
yield {'type': 'progress', 'progress': 100, 'message': '生成完成'}
except Exception as e:
logger.error(f"❌ 重新生成失败: {str(e)}", exc_info=True)
raise
def _build_modification_instructions(
self,
analysis: Optional[PlotAnalysis],
regenerate_request: ChapterRegenerateRequest
) -> str:
"""构建修改指令"""
instructions = []
# 标题
instructions.append("# 章节修改指令\n")
# 1. 来自分析的建议
if (analysis and
regenerate_request.selected_suggestion_indices and
analysis.suggestions):
instructions.append("## 📋 需要改进的问题(来自AI分析):\n")
for idx in regenerate_request.selected_suggestion_indices:
if 0 <= idx < len(analysis.suggestions):
suggestion = analysis.suggestions[idx]
instructions.append(f"{idx + 1}. {suggestion}")
instructions.append("")
# 2. 用户自定义指令
if regenerate_request.custom_instructions:
instructions.append("## ✍️ 用户自定义修改要求:\n")
instructions.append(regenerate_request.custom_instructions)
instructions.append("")
# 3. 重点优化方向
if regenerate_request.focus_areas:
instructions.append("## 🎯 重点优化方向:\n")
focus_map = {
"pacing": "节奏把控 - 调整叙事速度,避免拖沓或过快",
"emotion": "情感渲染 - 深化人物情感表达,增强感染力",
"description": "场景描写 - 丰富环境细节,增强画面感",
"dialogue": "对话质量 - 让对话更自然真实,推动剧情",
"conflict": "冲突强度 - 强化矛盾冲突,提升戏剧张力"
}
for area in regenerate_request.focus_areas:
if area in focus_map:
instructions.append(f"- {focus_map[area]}")
instructions.append("")
# 4. 保留要求
if regenerate_request.preserve_elements:
preserve = regenerate_request.preserve_elements
instructions.append("## 🔒 必须保留的元素:\n")
if preserve.preserve_structure:
instructions.append("- 保持原章节的整体结构和情节框架")
if preserve.preserve_dialogues:
instructions.append("- 必须保留以下关键对话:")
for dialogue in preserve.preserve_dialogues:
instructions.append(f" * {dialogue}")
if preserve.preserve_plot_points:
instructions.append("- 必须保留以下关键情节点:")
for plot in preserve.preserve_plot_points:
instructions.append(f" * {plot}")
if preserve.preserve_character_traits:
instructions.append("- 保持所有角色的性格特征和行为模式一致")
instructions.append("")
return "\n".join(instructions)
async def _build_regeneration_prompt(
self,
chapter: Chapter,
modification_instructions: str,
project_context: Dict[str, Any],
regenerate_request: ChapterRegenerateRequest,
style_content: str = "",
user_id: str = None,
db: AsyncSession = None
) -> str:
"""构建完整的重新生成提示词"""
# 使用PromptService的get_chapter_regeneration_prompt方法
# 该方法会处理自定义模板加载和完整提示词构建
return await PromptService.get_chapter_regeneration_prompt(
chapter_number=chapter.chapter_number,
title=chapter.title,
word_count=chapter.word_count,
content=chapter.content,
modification_instructions=modification_instructions,
project_context=project_context,
style_content=style_content,
target_word_count=regenerate_request.target_word_count,
user_id=user_id,
db=db
)
def calculate_content_diff(
self,
original_content: str,
new_content: str
) -> Dict[str, Any]:
"""
计算两个版本的差异
Returns:
差异统计信息
"""
# 基本统计
diff_stats = {
'original_length': len(original_content),
'new_length': len(new_content),
'length_change': len(new_content) - len(original_content),
'length_change_percent': round((len(new_content) - len(original_content)) / len(original_content) * 100, 2) if len(original_content) > 0 else 0
}
# 计算相似度
similarity = difflib.SequenceMatcher(None, original_content, new_content).ratio()
diff_stats['similarity'] = round(similarity * 100, 2)
diff_stats['difference'] = round((1 - similarity) * 100, 2)
# 段落统计
original_paragraphs = [p for p in original_content.split('\n\n') if p.strip()]
new_paragraphs = [p for p in new_content.split('\n\n') if p.strip()]
diff_stats['original_paragraph_count'] = len(original_paragraphs)
diff_stats['new_paragraph_count'] = len(new_paragraphs)
return diff_stats
# 全局实例
_regenerator_instance = None
def get_chapter_regenerator(ai_service: AIService) -> ChapterRegenerator:
"""获取章节重新生成器实例"""
global _regenerator_instance
if _regenerator_instance is None:
_regenerator_instance = ChapterRegenerator(ai_service)
return _regenerator_instance
+505 -32
View File
@@ -1,7 +1,7 @@
"""导入导出服务"""
import json
from datetime import datetime
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.models.project import Project
@@ -77,6 +77,8 @@ class ImportExportService:
"chapter_count": project.chapter_count,
"narrative_perspective": project.narrative_perspective,
"character_count": project.character_count,
"outline_mode": project.outline_mode,
"user_id": project.user_id,
"created_at": project.created_at.isoformat() if project.created_at else None,
}
@@ -143,18 +145,41 @@ class ImportExportService:
)
chapters = result.scalars().all()
return [
ChapterExportData(
# 构建大纲ID到标题的映射
outline_mapping = {}
if chapters:
outline_ids = [ch.outline_id for ch in chapters if ch.outline_id]
if outline_ids:
outline_result = await db.execute(
select(Outline).where(Outline.id.in_(outline_ids))
)
outlines = outline_result.scalars().all()
outline_mapping = {ol.id: ol.title for ol in outlines}
exported_chapters = []
for ch in chapters:
# 解析expansion_plan JSON
expansion_plan = None
if ch.expansion_plan:
try:
expansion_plan = json.loads(ch.expansion_plan) if isinstance(ch.expansion_plan, str) else ch.expansion_plan
except:
expansion_plan = None
exported_chapters.append(ChapterExportData(
title=ch.title,
content=ch.content,
summary=ch.summary,
chapter_number=ch.chapter_number,
word_count=ch.word_count or 0,
status=ch.status,
created_at=ch.created_at.isoformat() if ch.created_at else None
)
for ch in chapters
]
created_at=ch.created_at.isoformat() if ch.created_at else None,
outline_title=outline_mapping.get(ch.outline_id) if ch.outline_id else None,
sub_index=ch.sub_index,
expansion_plan=expansion_plan
))
return exported_chapters
@staticmethod
async def _export_characters(project_id: str, db: AsyncSession) -> List[CharacterExportData]:
@@ -315,10 +340,19 @@ class ImportExportService:
@staticmethod
async def _export_writing_styles(project_id: str, db: AsyncSession) -> List[WritingStyleExportData]:
"""导出写作风格"""
"""导出写作风格(用户自定义风格)"""
# 获取项目所属用户
project_result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = project_result.scalar_one_or_none()
if not project:
return []
# 导出该用户的自定义风格(不包括全局预设)
result = await db.execute(
select(WritingStyle)
.where(WritingStyle.project_id == project_id)
.where(WritingStyle.user_id == project.user_id)
.order_by(WritingStyle.order_index)
)
styles = result.scalars().all()
@@ -423,7 +457,8 @@ class ImportExportService:
@staticmethod
async def import_project(
data: Dict,
db: AsyncSession
db: AsyncSession,
user_id: str
) -> ImportResult:
"""
导入项目数据创建新项目
@@ -431,6 +466,7 @@ class ImportExportService:
Args:
data: 导入的JSON数据
db: 数据库会话
user_id: 目标用户ID导入后的项目归属
Returns:
ImportResult: 导入结果
@@ -456,6 +492,7 @@ class ImportExportService:
# 创建项目
project_data = data["project"]
new_project = Project(
user_id=user_id, # 设置为当前用户ID
title=project_data.get("title"),
description=project_data.get("description"),
theme=project_data.get("theme"),
@@ -469,6 +506,7 @@ class ImportExportService:
chapter_count=project_data.get("chapter_count"),
narrative_perspective=project_data.get("narrative_perspective"),
character_count=project_data.get("character_count"),
outline_mode=project_data.get("outline_mode", "one-to-many"), # ✅ 导入大纲模式,默认为一对多
current_words=project_data.get("current_words", 0), # 保留原项目的字数
wizard_step=4, # 导入的项目设置为向导完成状态
wizard_status="completed" # 标记向导已完成
@@ -478,26 +516,26 @@ class ImportExportService:
logger.info(f"创建项目成功: {new_project.id}")
# 导入章节
chapters_count = await ImportExportService._import_chapters(
new_project.id, data.get("chapters", []), db
)
statistics["chapters"] = chapters_count
logger.info(f"导入章节数: {chapters_count}")
# 导入角色(包括组织)
# 导入角色(包括组织)- 需要先导入角色,因为大纲可能需要角色信息
char_mapping = await ImportExportService._import_characters(
new_project.id, data.get("characters", []), db
)
statistics["characters"] = len(char_mapping)
logger.info(f"导入角色数: {len(char_mapping)}")
# 导入大纲
outlines_count = await ImportExportService._import_outlines(
# 导入大纲 - 需要在章节之前导入,以便建立关联
outline_mapping = await ImportExportService._import_outlines(
new_project.id, data.get("outlines", []), db
)
statistics["outlines"] = outlines_count
logger.info(f"导入大纲数: {outlines_count}")
statistics["outlines"] = len(outline_mapping)
logger.info(f"导入大纲数: {len(outline_mapping)}")
# 导入章节 - 使用大纲映射重建关联关系
chapters_count = await ImportExportService._import_chapters(
new_project.id, data.get("chapters", []), outline_mapping, db
)
statistics["chapters"] = chapters_count
logger.info(f"导入章节数: {chapters_count}")
# 导入关系
relationships_count = await ImportExportService._import_relationships(
@@ -554,11 +592,23 @@ class ImportExportService:
async def _import_chapters(
project_id: str,
chapters_data: List[Dict],
outline_mapping: Dict[str, str],
db: AsyncSession
) -> int:
"""导入章节"""
count = 0
for ch_data in chapters_data:
# 根据大纲标题查找对应的新大纲ID
outline_id = None
outline_title = ch_data.get("outline_title")
if outline_title and outline_title in outline_mapping:
outline_id = outline_mapping[outline_title]
# 处理expansion_plan
expansion_plan = ch_data.get("expansion_plan")
if expansion_plan and isinstance(expansion_plan, dict):
expansion_plan = json.dumps(expansion_plan, ensure_ascii=False)
chapter = Chapter(
project_id=project_id,
title=ch_data.get("title"),
@@ -566,7 +616,10 @@ class ImportExportService:
summary=ch_data.get("summary"),
chapter_number=ch_data.get("chapter_number"),
word_count=ch_data.get("word_count", 0),
status=ch_data.get("status", "draft")
status=ch_data.get("status", "draft"),
outline_id=outline_id,
sub_index=ch_data.get("sub_index"),
expansion_plan=expansion_plan
)
db.add(chapter)
count += 1
@@ -585,7 +638,7 @@ class ImportExportService:
for char_data in characters_data:
# 处理traits
traits = char_data.get("traits")
if traits and isinstance(traits, list):
if isinstance(traits, list):
traits = json.dumps(traits, ensure_ascii=False)
character = Character(
@@ -613,9 +666,10 @@ class ImportExportService:
project_id: str,
outlines_data: List[Dict],
db: AsyncSession
) -> int:
"""导入大纲"""
count = 0
) -> Dict[str, str]:
"""导入大纲,返回标题到ID的映射"""
outline_mapping = {}
for ol_data in outlines_data:
outline = Outline(
project_id=project_id,
@@ -625,9 +679,10 @@ class ImportExportService:
order_index=ol_data.get("order_index")
)
db.add(outline)
count += 1
await db.flush() # 获取ID
outline_mapping[ol_data.get("title")] = outline.id
return count
return outline_mapping
@staticmethod
async def _import_relationships(
@@ -751,11 +806,30 @@ class ImportExportService:
styles_data: List[Dict],
db: AsyncSession
) -> int:
"""导入写作风格"""
"""导入写作风格(用户自定义风格)"""
# 获取项目所属用户
project_result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = project_result.scalar_one_or_none()
if not project:
return 0
count = 0
for style_data in styles_data:
# 检查是否已存在同名风格(避免重复导入)
existing = await db.execute(
select(WritingStyle).where(
WritingStyle.user_id == project.user_id,
WritingStyle.name == style_data.get("name")
)
)
if existing.scalar_one_or_none():
logger.debug(f"风格 {style_data.get('name')} 已存在,跳过导入")
continue
style = WritingStyle(
project_id=project_id,
user_id=project.user_id, # 使用 user_id 而不是 project_id
name=style_data.get("name"),
style_type=style_data.get("style_type"),
preset_id=style_data.get("preset_id"),
@@ -766,4 +840,403 @@ class ImportExportService:
db.add(style)
count += 1
return count
return count
@staticmethod
async def export_characters(
character_ids: List[str],
db: AsyncSession
) -> Dict[str, Any]:
"""
导出角色/组织卡片
Args:
character_ids: 要导出的角色/组织ID列表
db: 数据库会话
Returns:
Dict: 导出的角色数据
"""
logger.info(f"开始导出角色/组织: {len(character_ids)}")
# 查询角色数据
result = await db.execute(
select(Character).where(Character.id.in_(character_ids))
)
characters = result.scalars().all()
if not characters:
raise ValueError("未找到指定的角色/组织")
# 导出角色数据
exported_characters = []
for char in characters:
# 解析 traits
traits = None
if char.traits:
try:
traits = json.loads(char.traits) if isinstance(char.traits, str) else char.traits
except:
traits = None
# 基础角色数据
char_data = {
"name": char.name,
"age": char.age,
"gender": char.gender,
"is_organization": char.is_organization or False,
"role_type": char.role_type,
"personality": char.personality,
"background": char.background,
"appearance": char.appearance,
"relationships": char.relationships,
"traits": traits,
"organization_type": char.organization_type,
"organization_purpose": char.organization_purpose,
"organization_members": char.organization_members,
"avatar_url": char.avatar_url,
"main_career_id": char.main_career_id,
"main_career_stage": char.main_career_stage,
"sub_careers": char.sub_careers,
"created_at": char.created_at.isoformat() if char.created_at else None
}
# 如果是组织,添加组织专属字段
if char.is_organization:
org_result = await db.execute(
select(Organization).where(Organization.character_id == char.id)
)
org = org_result.scalar_one_or_none()
if org:
char_data.update({
"power_level": org.power_level,
"location": org.location,
"motto": org.motto,
"color": org.color
})
exported_characters.append(char_data)
export_data = {
"version": ImportExportService.SUPPORTED_VERSION,
"export_time": datetime.utcnow().isoformat(),
"export_type": "characters",
"count": len(exported_characters),
"data": exported_characters
}
logger.info(f"角色/组织导出完成: {len(exported_characters)}")
return export_data
@staticmethod
async def import_characters(
data: Dict,
project_id: str,
user_id: str,
db: AsyncSession
) -> Dict[str, Any]:
"""
导入角色/组织卡片
Args:
data: 导入的JSON数据
project_id: 目标项目ID
user_id: 用户ID
db: 数据库会话
Returns:
Dict: 导入结果
"""
from app.models.career import CharacterCareer, Career
warnings = []
imported_characters = []
imported_organizations = []
skipped = []
errors = []
try:
# 验证数据格式
if "data" not in data:
raise ValueError("导入数据格式错误:缺少data字段")
characters_data = data["data"]
if not isinstance(characters_data, list):
raise ValueError("导入数据格式错误:data字段必须是数组")
# 验证项目权限
project_result = await db.execute(
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = project_result.scalar_one_or_none()
if not project:
raise ValueError("项目不存在或无权访问")
logger.info(f"开始导入 {len(characters_data)} 个角色/组织到项目 {project_id}")
# 处理每个角色/组织
for idx, char_data in enumerate(characters_data):
try:
name = char_data.get("name")
if not name:
errors.append(f"{idx+1}个角色缺少name字段")
continue
# 检查重复名称
existing_result = await db.execute(
select(Character).where(
Character.project_id == project_id,
Character.name == name
)
)
existing = existing_result.scalar_one_or_none()
if existing:
warnings.append(f"角色'{name}'已存在,已跳过")
skipped.append(name)
continue
# 处理traits
traits = char_data.get("traits")
if isinstance(traits, list):
traits = json.dumps(traits, ensure_ascii=False)
is_organization = char_data.get("is_organization", False)
# 创建角色
character = Character(
project_id=project_id,
name=name,
age=char_data.get("age"),
gender=char_data.get("gender"),
is_organization=is_organization,
role_type=char_data.get("role_type"),
personality=char_data.get("personality"),
background=char_data.get("background"),
appearance=char_data.get("appearance"),
relationships=char_data.get("relationships"),
traits=traits,
organization_type=char_data.get("organization_type"),
organization_purpose=char_data.get("organization_purpose"),
organization_members=char_data.get("organization_members"),
avatar_url=char_data.get("avatar_url"),
main_career_id=None, # 职业ID需要验证后再设置
main_career_stage=char_data.get("main_career_stage"),
sub_careers=None # 副职业需要验证后再设置
)
db.add(character)
await db.flush() # 获取character.id
# 处理主职业(如果有)
main_career_id = char_data.get("main_career_id")
main_career_stage = char_data.get("main_career_stage")
if main_career_id and not is_organization:
# 验证职业是否存在
career_result = await db.execute(
select(Career).where(
Career.id == main_career_id,
Career.project_id == project_id,
Career.type == 'main'
)
)
career = career_result.scalar_one_or_none()
if career:
character.main_career_id = main_career_id
character.main_career_stage = main_career_stage or 1
# 创建职业关联
char_career = CharacterCareer(
character_id=character.id,
career_id=main_career_id,
career_type='main',
current_stage=main_career_stage or 1,
stage_progress=0
)
db.add(char_career)
else:
warnings.append(f"角色'{name}'的主职业ID不存在,已忽略职业信息")
# 处理副职业(如果有)
sub_careers = char_data.get("sub_careers")
if sub_careers and not is_organization:
try:
sub_careers_data = json.loads(sub_careers) if isinstance(sub_careers, str) else sub_careers
if isinstance(sub_careers_data, list):
valid_sub_careers = []
for sub_data in sub_careers_data[:2]: # 最多2个副职业
if isinstance(sub_data, dict):
career_id = sub_data.get('career_id')
stage = sub_data.get('stage', 1)
if career_id:
# 验证副职业是否存在
career_result = await db.execute(
select(Career).where(
Career.id == career_id,
Career.project_id == project_id,
Career.type == 'sub'
)
)
career = career_result.scalar_one_or_none()
if career:
valid_sub_careers.append({
'career_id': career_id,
'stage': stage
})
# 创建副职业关联
char_career = CharacterCareer(
character_id=character.id,
career_id=career_id,
career_type='sub',
current_stage=stage,
stage_progress=0
)
db.add(char_career)
if valid_sub_careers:
character.sub_careers = json.dumps(valid_sub_careers, ensure_ascii=False)
elif sub_careers_data:
warnings.append(f"角色'{name}'的副职业ID不存在,已忽略副职业信息")
except Exception as e:
warnings.append(f"角色'{name}'的副职业数据解析失败: {str(e)}")
# 如果是组织,创建Organization记录
if is_organization:
organization = Organization(
character_id=character.id,
project_id=project_id,
member_count=0,
power_level=char_data.get("power_level", 50),
location=char_data.get("location"),
motto=char_data.get("motto"),
color=char_data.get("color")
)
db.add(organization)
await db.flush()
imported_organizations.append(name)
else:
imported_characters.append(name)
logger.info(f"导入{'组织' if is_organization else '角色'}成功: {name}")
except Exception as e:
error_msg = f"导入角色'{char_data.get('name', f'{idx+1}')}'失败: {str(e)}"
logger.error(error_msg)
errors.append(error_msg)
continue
# 提交事务
await db.commit()
total = len(imported_characters) + len(imported_organizations)
result = {
"success": True,
"message": f"成功导入 {total} 个角色/组织",
"statistics": {
"total": len(characters_data),
"imported": total,
"skipped": len(skipped),
"errors": len(errors)
},
"details": {
"imported_characters": imported_characters,
"imported_organizations": imported_organizations,
"skipped": skipped,
"errors": errors
},
"warnings": warnings
}
logger.info(f"角色/组织导入完成: 成功{total}个,跳过{len(skipped)}个,失败{len(errors)}")
return result
except Exception as e:
await db.rollback()
logger.error(f"导入角色/组织失败: {str(e)}", exc_info=True)
return {
"success": False,
"message": f"导入失败: {str(e)}",
"statistics": {
"total": len(characters_data) if "data" in data else 0,
"imported": len(imported_characters) + len(imported_organizations),
"skipped": len(skipped),
"errors": len(errors)
},
"details": {
"imported_characters": imported_characters,
"imported_organizations": imported_organizations,
"skipped": skipped,
"errors": errors
},
"warnings": warnings
}
@staticmethod
def validate_characters_import(data: Dict) -> Dict[str, Any]:
"""
验证角色/组织导入数据
Args:
data: 导入的JSON数据
Returns:
Dict: 验证结果
"""
errors = []
warnings = []
# 检查版本
version = data.get("version", "")
if not version:
errors.append("缺少版本信息")
elif version != ImportExportService.SUPPORTED_VERSION:
warnings.append(f"版本不匹配: 导入文件版本为 {version}, 当前支持版本为 {ImportExportService.SUPPORTED_VERSION}")
# 检查导出类型
export_type = data.get("export_type", "")
if export_type != "characters":
errors.append(f"导出类型错误: 期望'characters',实际'{export_type}'")
# 检查数据字段
if "data" not in data:
errors.append("缺少data字段")
elif not isinstance(data["data"], list):
errors.append("data字段必须是数组")
else:
characters_data = data["data"]
# 统计信息
character_count = sum(1 for c in characters_data if not c.get("is_organization", False))
org_count = sum(1 for c in characters_data if c.get("is_organization", False))
# 检查必填字段
for idx, char_data in enumerate(characters_data):
if not char_data.get("name"):
errors.append(f"{idx+1}个角色缺少name字段")
statistics = {
"characters": character_count,
"organizations": org_count
}
if "data" not in data or errors:
statistics = {"characters": 0, "organizations": 0}
return {
"valid": len(errors) == 0,
"version": version,
"statistics": statistics,
"errors": errors,
"warnings": warnings
}
+159
View File
@@ -0,0 +1,159 @@
"""JSON 处理工具类"""
import json
import re
from typing import Any, Dict, List, Union
from app.logger import get_logger
logger = get_logger(__name__)
def clean_json_response(text: str) -> str:
"""清洗 AI 返回的 JSON(改进版 - 流式安全)"""
try:
if not text:
logger.warning("⚠️ clean_json_response: 输入为空")
return text
original_length = len(text)
logger.debug(f"🔍 开始清洗JSON,原始长度: {original_length}")
# 去除 markdown 代码块
text = re.sub(r'^```json\s*\n?', '', text, flags=re.MULTILINE | re.IGNORECASE)
text = re.sub(r'^```\s*\n?', '', text, flags=re.MULTILINE)
text = re.sub(r'\n?```\s*$', '', text, flags=re.MULTILINE)
text = text.strip()
if len(text) != original_length:
logger.debug(f" 移除markdown后长度: {len(text)}")
# 尝试直接解析(快速路径)
try:
json.loads(text)
logger.debug(f"✅ 直接解析成功,无需清洗")
return text
except:
pass
# 找到第一个 { 或 [
start = -1
for i, c in enumerate(text):
if c in ('{', '['):
start = i
break
if start == -1:
logger.warning(f"⚠️ 未找到JSON起始符号 {{ 或 [")
logger.debug(f" 文本预览: {text[:200]}")
return text
if start > 0:
logger.debug(f" 跳过前{start}个字符")
text = text[start:]
# 改进的括号匹配算法(更严格的字符串处理)
stack = []
i = 0
end = -1
in_string = False
while i < len(text):
c = text[i]
# 处理字符串状态
if c == '"':
if not in_string:
# 进入字符串
in_string = True
else:
# 检查是否是转义的引号
num_backslashes = 0
j = i - 1
while j >= 0 and text[j] == '\\':
num_backslashes += 1
j -= 1
# 偶数个反斜杠表示引号未被转义,字符串结束
if num_backslashes % 2 == 0:
in_string = False
i += 1
continue
# 在字符串内部,跳过所有字符
if in_string:
i += 1
continue
# 处理括号(只有在字符串外部才有效)
if c == '{' or c == '[':
stack.append(c)
elif c == '}':
if len(stack) > 0 and stack[-1] == '{':
stack.pop()
if len(stack) == 0:
end = i + 1
logger.debug(f"✅ 找到JSON结束位置: {end}")
break
elif len(stack) > 0:
# 括号不匹配,可能是损坏的JSON,尝试继续
logger.warning(f"⚠️ 括号不匹配:遇到 }} 但栈顶是 {stack[-1]}")
else:
# 栈为空遇到 },忽略多余的闭合括号
logger.warning(f"⚠️ 遇到多余的 }},忽略")
elif c == ']':
if len(stack) > 0 and stack[-1] == '[':
stack.pop()
if len(stack) == 0:
end = i + 1
logger.debug(f"✅ 找到JSON结束位置: {end}")
break
elif len(stack) > 0:
# 括号不匹配,可能是损坏的JSON,尝试继续
logger.warning(f"⚠️ 括号不匹配:遇到 ] 但栈顶是 {stack[-1]}")
else:
# 栈为空遇到 ],忽略多余的闭合括号
logger.warning(f"⚠️ 遇到多余的 ],忽略")
i += 1
# 检查未闭合的字符串
if in_string:
logger.warning(f"⚠️ 字符串未闭合,JSON可能不完整")
# 提取结果
if end > 0:
result = text[:end]
logger.debug(f"✅ JSON清洗完成,结果长度: {len(result)}")
else:
result = text
logger.warning(f"⚠️ 未找到JSON结束位置,返回全部内容(长度: {len(result)}")
logger.debug(f" 栈状态: {stack}")
# 验证清洗后的结果
try:
json.loads(result)
logger.debug(f"✅ 清洗后JSON验证成功")
except json.JSONDecodeError as e:
logger.error(f"❌ 清洗后JSON仍然无效: {e}")
logger.debug(f" 结果预览: {result[:500]}")
logger.debug(f" 结果结尾: ...{result[-200:]}")
return result
except Exception as e:
logger.error(f"❌ clean_json_response 出错: {e}")
logger.error(f" 文本长度: {len(text) if text else 0}")
logger.error(f" 文本预览: {text[:200] if text else 'None'}")
raise
def parse_json(text: str) -> Union[Dict, List]:
"""解析 JSON"""
try:
cleaned = clean_json_response(text)
return json.loads(cleaned)
except Exception as e:
logger.error(f"❌ parse_json 出错: {e}")
logger.error(f" 原始文本长度: {len(text) if text else 0}")
logger.error(f" 清洗后文本长度: {len(cleaned) if cleaned else 0}")
raise
+347
View File
@@ -0,0 +1,347 @@
"""MCP插件测试服务 - 专门处理插件测试逻辑
重构后使用统一的MCPClientFacade门面来管理所有MCP操作
"""
import time
import json
from typing import Dict, Any, Optional
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.models.mcp_plugin import MCPPlugin
from app.models.settings import Settings as UserSettings
from app.mcp import mcp_client, MCPPluginConfig # 使用新的统一门面
from app.services.ai_service import create_user_ai_service
from app.schemas.mcp_plugin import MCPTestResult
from app.services.prompt_service import prompt_service
from app.logger import get_logger
from app.user_manager import User
logger = get_logger(__name__)
class MCPTestService:
"""MCP插件测试服务(使用统一门面重构)"""
async def _ensure_plugin_registered(
self,
plugin: MCPPlugin,
user_id: str
) -> bool:
"""
确保插件已注册到统一门面
Args:
plugin: 插件配置
user_id: 用户ID
Returns:
是否成功
"""
if plugin.plugin_type in ("http", "streamable_http", "sse") and plugin.server_url:
return await mcp_client.ensure_registered(
user_id=user_id,
plugin_name=plugin.plugin_name,
url=plugin.server_url,
plugin_type=plugin.plugin_type,
headers=plugin.headers
)
return False
async def test_plugin_connection(
self,
plugin: MCPPlugin,
user_id: str
) -> MCPTestResult:
"""
简单连接测试
Args:
plugin: 插件配置
user_id: 用户ID
Returns:
测试结果
"""
start_time = time.time()
try:
# 确保插件已注册
registered = await self._ensure_plugin_registered(plugin, user_id)
if not registered:
return MCPTestResult(
success=False,
message="插件注册失败",
error="无法创建MCP客户端",
suggestions=["请检查插件配置", "请确认服务器URL正确"]
)
# 使用统一门面测试连接
test_result = await mcp_client.test_connection(user_id, plugin.plugin_name)
end_time = time.time()
response_time = round((end_time - start_time) * 1000, 2)
if test_result["success"]:
return MCPTestResult(
success=True,
message=f"✅ 连接测试成功",
response_time_ms=response_time,
tools_count=test_result.get("tools_count", 0),
suggestions=[
f"响应时间: {response_time}ms",
f"可用工具数: {test_result.get('tools_count', 0)}"
]
)
else:
return MCPTestResult(
success=False,
message="❌ 连接测试失败",
response_time_ms=response_time,
error=test_result.get("message", "未知错误"),
error_type=test_result.get("error_type"),
suggestions=[
"请检查服务器是否在线",
"请确认配置正确",
"请检查API Key是否有效"
]
)
except Exception as e:
end_time = time.time()
response_time = round((end_time - start_time) * 1000, 2)
logger.error(f"测试插件失败: {plugin.plugin_name}, 错误: {e}")
return MCPTestResult(
success=False,
message="❌ 测试失败",
response_time_ms=response_time,
error=str(e),
error_type=type(e).__name__,
suggestions=[
"请检查服务器是否在线",
"请确认配置正确",
"请检查API Key是否有效"
]
)
async def test_plugin_with_ai(
self,
plugin: MCPPlugin,
user: User,
db_session: AsyncSession
) -> MCPTestResult:
"""
使用AI进行智能工具调用测试
Args:
plugin: 插件配置
user: 用户对象
db_session: 数据库会话
Returns:
测试结果
"""
start_time = time.time()
try:
# 1. 先进行连接测试
connection_result = await self.test_plugin_connection(plugin, user.user_id)
if not connection_result.success:
return connection_result
# 2. 使用统一门面获取工具列表
tools = await mcp_client.get_tools(user.user_id, plugin.plugin_name)
if not tools:
return MCPTestResult(
success=False,
message="插件没有提供任何工具",
error="工具列表为空",
response_time_ms=connection_result.response_time_ms,
suggestions=["请检查插件配置", "请确认MCP服务器正常运行"]
)
# 3. 获取用户的AI设置
settings_result = await db_session.execute(
select(UserSettings).where(UserSettings.user_id == user.user_id)
)
user_settings = settings_result.scalar_one_or_none()
if not user_settings or not user_settings.api_key:
# 没有AI配置,返回简单测试结果
logger.warning("用户未配置AI服务,跳过智能测试")
return MCPTestResult(
success=True,
message=f"✅ 连接测试成功(未配置AI,跳过工具调用测试)",
response_time_ms=connection_result.response_time_ms,
tools_count=len(tools),
suggestions=[
f"连接测试: 成功",
f"可用工具数: {len(tools)}",
"提示: 配置AI服务后可进行智能工具调用测试"
]
)
# 4. 使用AI选择工具并生成测试参数
logger.info(f"使用AI分析工具并生成测试计划...")
ai_service = create_user_ai_service(
api_provider=user_settings.api_provider,
api_key=user_settings.api_key,
api_base_url=user_settings.api_base_url,
model_name=user_settings.llm_model,
temperature=0.3,
max_tokens=1000
)
# 使用统一门面转换为OpenAI Function Calling格式
openai_tools = mcp_client.format_tools_for_openai(tools, plugin.plugin_name)
logger.info(f"📋 转换后的OpenAI工具数量: {len(openai_tools)}")
logger.debug(f"📋 OpenAI工具列表: {[t['function']['name'] for t in openai_tools]}")
# 调用AI选择工具(使用自定义模板系统)
prompts = await prompt_service.get_mcp_tool_test_prompts(
plugin_name=plugin.plugin_name,
user_id=user.user_id,
db=db_session
)
# 使用 generate_text 进行 Function Calling(非流式)
ai_response = await ai_service.generate_text(
prompt=prompts["user"],
system_prompt=prompts["system"],
tools=openai_tools,
tool_choice="auto"
)
accumulated_text = ai_response.get("content", "")
tool_calls = ai_response.get("tool_calls")
# 5. 检查AI是否返回工具调用
if not tool_calls:
logger.error(f"❌ AI未返回工具调用")
return MCPTestResult(
success=False,
message="❌ AI Function Calling失败",
error=f"AI未返回工具调用请求。响应: {accumulated_text[:200] if accumulated_text else 'N/A'}",
tools_count=len(tools),
suggestions=[
"请确认使用的AI模型支持Function Calling",
f"当前Provider: {user_settings.api_provider}",
f"当前模型: {user_settings.llm_model}"
]
)
# 6. 解析工具调用
tool_call = tool_calls[0]
function = tool_call["function"]
tool_name_with_prefix = function["name"]
test_arguments = function["arguments"]
if isinstance(test_arguments, str):
try:
# 使用统一的JSON清洗方法
cleaned_args = ai_service._clean_json_response(test_arguments)
test_arguments = json.loads(cleaned_args)
except json.JSONDecodeError as e:
logger.error(f"❌ 解析AI参数失败: {e}")
return MCPTestResult(
success=False,
message="❌ AI返回的参数格式错误",
error=f"无法解析参数JSON: {str(e)}",
tools_count=len(tools)
)
# 解析插件名和工具名
try:
_, tool_name = mcp_client.parse_function_name(tool_name_with_prefix)
except ValueError:
tool_name = tool_name_with_prefix
logger.info(f"🤖 AI选择的工具: {tool_name}")
logger.info(f"📝 AI生成的参数: {test_arguments}")
# 7. 使用统一门面调用MCP工具
call_start = time.time()
try:
tool_result = await mcp_client.call_tool(
user_id=user.user_id,
plugin_name=plugin.plugin_name,
tool_name=tool_name,
arguments=test_arguments
)
call_end = time.time()
call_time = round((call_end - call_start) * 1000, 2)
total_time = round((call_end - start_time) * 1000, 2)
# 格式化结果
result_str = str(tool_result)
if len(result_str) > 800:
result_preview = result_str[:800] + "\n...(结果已截断)"
else:
result_preview = result_str
return MCPTestResult(
success=True,
message=f"✅ Function Calling测试成功!工具 '{tool_name}' 调用正常",
response_time_ms=total_time,
tools_count=len(tools),
suggestions=[
f"🤖 AI选择: {tool_name}",
f"📝 参数: {json.dumps(test_arguments, ensure_ascii=False)}",
f"⏱️ 耗时: {call_time}ms",
f"📊 结果:\n{result_preview}"
]
)
except Exception as call_error:
call_end = time.time()
total_time = round((call_end - start_time) * 1000, 2)
logger.warning(f"工具调用失败: {tool_name}, 错误: {call_error}")
return MCPTestResult(
success=True, # 连接成功就算测试通过
message=f"⚠️ 连接成功,但工具调用失败",
response_time_ms=total_time,
tools_count=len(tools),
error=f"工具 '{tool_name}' 调用失败: {str(call_error)}",
suggestions=[
f"✅ 连接测试: 成功",
f"❌ 工具调用测试: 失败",
f"🤖 AI选择: {tool_name}",
f"❌ 错误: {str(call_error)}",
"💡 可能原因: API Key无效、参数错误或服务限制"
]
)
except Exception as e:
end_time = time.time()
total_time = round((end_time - start_time) * 1000, 2)
logger.error(f"测试插件失败: {plugin.plugin_name}, 错误: {e}")
return MCPTestResult(
success=False,
message="❌ 测试失败",
response_time_ms=total_time,
error=str(e),
error_type=type(e).__name__,
suggestions=[
"请检查服务器是否在线",
"请确认配置正确",
"请检查API Key是否有效"
]
)
# 全局单例
mcp_test_service = MCPTestService()
+235
View File
@@ -0,0 +1,235 @@
"""MCP工具加载器 - 统一的工具获取入口
在AI请求之前自动检查用户MCP配置并加载可用工具
"""
from typing import Optional, List, Dict, Any
from datetime import datetime, timedelta
from dataclasses import dataclass, field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.logger import get_logger
from app.models.mcp_plugin import MCPPlugin
from app.mcp import mcp_client
logger = get_logger(__name__)
@dataclass
class UserToolsCache:
"""用户工具缓存条目"""
tools: Optional[List[Dict[str, Any]]]
expire_time: datetime
hit_count: int = 0
class MCPToolsLoader:
"""
MCP工具加载器
负责
1. 检查用户是否配置并启用了MCP插件
2. 从各个启用的插件加载工具列表
3. 将工具转换为OpenAI Function Calling格式
4. 缓存结果以提升性能
"""
_instance: Optional['MCPToolsLoader'] = None
def __new__(cls):
"""单例模式"""
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
if self._initialized:
return
# 用户工具缓存: user_id -> UserToolsCache
self._cache: Dict[str, UserToolsCache] = {}
# 缓存TTL5分钟)
self._cache_ttl = timedelta(minutes=5)
self._initialized = True
logger.info("✅ MCPToolsLoader 初始化完成")
async def has_enabled_plugins(
self,
user_id: str,
db_session: AsyncSession
) -> bool:
"""
检查用户是否有启用的MCP插件
Args:
user_id: 用户ID
db_session: 数据库会话
Returns:
是否有启用的插件
"""
try:
query = select(MCPPlugin.id).where(
MCPPlugin.user_id == user_id,
MCPPlugin.enabled == True,
MCPPlugin.plugin_type.in_(["http", "streamable_http", "sse"])
).limit(1)
result = await db_session.execute(query)
return result.scalar() is not None
except Exception as e:
logger.warning(f"检查用户MCP插件失败: {e}")
return False
async def get_user_tools(
self,
user_id: str,
db_session: AsyncSession,
use_cache: bool = True,
force_refresh: bool = False
) -> Optional[List[Dict[str, Any]]]:
"""
获取用户的MCP工具列表OpenAI格式
Args:
user_id: 用户ID
db_session: 数据库会话
use_cache: 是否使用缓存
force_refresh: 是否强制刷新
Returns:
- None: 用户未配置或未启用任何MCP插件
- []: 有配置但没有可用工具
- List[Dict]: OpenAI Function Calling格式的工具列表
"""
now = datetime.now()
# 检查缓存
if use_cache and not force_refresh and user_id in self._cache:
cache_entry = self._cache[user_id]
if now < cache_entry.expire_time:
cache_entry.hit_count += 1
logger.debug(f"🎯 用户工具缓存命中: {user_id} (命中次数: {cache_entry.hit_count})")
return cache_entry.tools
else:
del self._cache[user_id]
logger.debug(f"⏰ 用户工具缓存过期: {user_id}")
# 从数据库加载
try:
tools = await self._load_user_tools(user_id, db_session)
# 更新缓存
self._cache[user_id] = UserToolsCache(
tools=tools,
expire_time=now + self._cache_ttl
)
if tools:
logger.info(f"🔧 用户 {user_id} 加载了 {len(tools)} 个MCP工具")
else:
logger.debug(f"📭 用户 {user_id} 没有可用的MCP工具")
return tools
except Exception as e:
logger.error(f"❌ 加载用户MCP工具失败: {e}")
return None
async def _load_user_tools(
self,
user_id: str,
db_session: AsyncSession
) -> Optional[List[Dict[str, Any]]]:
"""
从数据库加载用户启用的MCP插件并获取工具
"""
# 查询启用的插件
query = select(MCPPlugin).where(
MCPPlugin.user_id == user_id,
MCPPlugin.enabled == True,
MCPPlugin.plugin_type.in_(["http", "streamable_http", "sse"])
).order_by(MCPPlugin.sort_order)
result = await db_session.execute(query)
plugins = result.scalars().all()
if not plugins:
return None
all_tools = []
for plugin in plugins:
try:
# 确定插件类型
plugin_type = plugin.plugin_type
if plugin_type == "http":
plugin_type = "streamable_http" # 默认使用streamable_http
# 确保插件已注册到MCP客户端
await mcp_client.ensure_registered(
user_id=user_id,
plugin_name=plugin.plugin_name,
url=plugin.server_url,
plugin_type=plugin_type,
headers=plugin.headers
)
# 获取工具列表
plugin_tools = await mcp_client.get_tools(user_id, plugin.plugin_name)
# 转换为OpenAI格式
formatted = mcp_client.format_tools_for_openai(plugin_tools, plugin.plugin_name)
all_tools.extend(formatted)
logger.debug(f"✅ 从插件 {plugin.plugin_name} 加载了 {len(formatted)} 个工具")
except Exception as e:
logger.warning(f"⚠️ 加载插件 {plugin.plugin_name} 工具失败: {e}")
continue
return all_tools if all_tools else None
def invalidate_cache(self, user_id: Optional[str] = None):
"""
使缓存失效
Args:
user_id: 用户ID为None时清空所有缓存
"""
if user_id:
if user_id in self._cache:
del self._cache[user_id]
logger.debug(f"🧹 清理用户工具缓存: {user_id}")
else:
count = len(self._cache)
self._cache.clear()
logger.info(f"🧹 清理所有用户工具缓存 ({count}个)")
def get_cache_stats(self) -> Dict[str, Any]:
"""获取缓存统计"""
now = datetime.now()
return {
"total_entries": len(self._cache),
"total_hits": sum(e.hit_count for e in self._cache.values()),
"cache_ttl_minutes": self._cache_ttl.total_seconds() / 60,
"entries": [
{
"user_id": uid,
"tools_count": len(e.tools) if e.tools else 0,
"hit_count": e.hit_count,
"expired": now >= e.expire_time,
"expire_time": e.expire_time.isoformat()
}
for uid, e in self._cache.items()
]
}
# 全局单例
mcp_tools_loader = MCPToolsLoader()
+162 -22
View File
@@ -10,10 +10,60 @@ import hashlib
logger = get_logger(__name__)
# 配置模型缓存目录(不设置离线模式,让它自动选择)
# 如果本地有模型就用本地的,没有才联网下载
# 配置模型缓存目录
# 优先使用 backend/embedding 目录(打包后的实际位置)
import sys
from pathlib import Path
if 'SENTENCE_TRANSFORMERS_HOME' not in os.environ:
os.environ['SENTENCE_TRANSFORMERS_HOME'] = 'embedding'
# 根据运行环境确定模型目录
if getattr(sys, 'frozen', False):
# PyInstaller 打包后 - 需要检查多个可能的位置
exe_dir = Path(sys.executable).parent
# 检查顺序:
# 1. _MEIPASS/backend/embedding (临时解压目录)
# 2. exe同级/_internal/backend/embedding
# 3. exe同级/backend/embedding
possible_paths = []
if hasattr(sys, '_MEIPASS'):
possible_paths.append(Path(sys._MEIPASS) / 'backend' / 'embedding')
possible_paths.extend([
exe_dir / '_internal' / 'backend' / 'embedding',
exe_dir / 'backend' / 'embedding',
exe_dir / '_internal' / 'embedding',
exe_dir / 'embedding'
])
model_dir = None
for path in possible_paths:
if path.exists():
model_dir = path
logger.info(f"🔧 找到打包环境模型目录: {model_dir}")
break
if model_dir:
os.environ['SENTENCE_TRANSFORMERS_HOME'] = str(model_dir)
else:
# 最后降级方案
fallback_dir = exe_dir / 'embedding'
os.environ['SENTENCE_TRANSFORMERS_HOME'] = str(fallback_dir)
logger.warning(f"⚠️ 未找到预打包模型,使用降级目录: {fallback_dir}")
logger.warning(f" 检查过的路径: {[str(p) for p in possible_paths]}")
else:
# 开发模式,从当前文件位置向上找到项目根目录
base_dir = Path(__file__).parent.parent.parent
model_dir = base_dir / 'backend' / 'embedding'
if model_dir.exists():
os.environ['SENTENCE_TRANSFORMERS_HOME'] = str(model_dir)
logger.info(f"🔧 设置开发环境模型目录: {model_dir}")
else:
# 降级到项目根目录的 embedding
fallback_dir = base_dir / 'embedding'
os.environ['SENTENCE_TRANSFORMERS_HOME'] = str(fallback_dir)
logger.info(f"🔧 使用降级模型目录: {fallback_dir}")
class MemoryService:
@@ -44,9 +94,10 @@ class MemoryService:
# 初始化多语言embedding模型(支持中文)
logger.info("🔄 正在加载Embedding模型...")
# 确保模型缓存目录存在
model_cache_dir = 'embedding'
# 使用环境变量中配置的模型目录
model_cache_dir = os.environ.get('SENTENCE_TRANSFORMERS_HOME', 'embedding')
os.makedirs(model_cache_dir, exist_ok=True)
logger.info(f"📂 使用模型缓存目录: {os.path.abspath(model_cache_dir)}")
# 调试信息:打印环境变量和路径
logger.info(f"📂 当前工作目录: {os.getcwd()}")
@@ -56,40 +107,91 @@ class MemoryService:
logger.info(f"🔧 HF_HUB_OFFLINE: {os.environ.get('HF_HUB_OFFLINE', '未设置')}")
# 检查模型目录内容
if os.path.exists(model_cache_dir):
abs_cache_dir = os.path.abspath(model_cache_dir)
logger.info(f"📂 检查模型缓存目录: {abs_cache_dir}")
if os.path.exists(abs_cache_dir):
logger.info(f"📁 模型目录存在,检查内容...")
try:
items = os.listdir(model_cache_dir)
logger.info(f"📁 模型目录内容: {items}")
items = os.listdir(abs_cache_dir)
logger.info(f"📁 模型目录内容 ({len(items)} 项): {items}")
# 检查是否有预期的模型文件夹
expected_model_dir = os.path.join(model_cache_dir, 'models--sentence-transformers--paraphrase-multilingual-MiniLM-L12-v2')
expected_model_dir = os.path.join(abs_cache_dir, 'models--sentence-transformers--paraphrase-multilingual-MiniLM-L12-v2')
logger.info(f"🔍 检查预期路径: {expected_model_dir}")
if os.path.exists(expected_model_dir):
logger.info(f"✅ 找到本地模型目录: {expected_model_dir}")
logger.info(f"✅ 找到本地模型目录!")
# 检查快照目录
snapshots_dir = os.path.join(expected_model_dir, 'snapshots')
if os.path.exists(snapshots_dir):
snapshots = os.listdir(snapshots_dir)
logger.info(f"📁 模型快照: {snapshots}")
logger.info(f"📁 模型快照 ({len(snapshots)} 个): {snapshots}")
# 检查是否有有效的快照
if snapshots:
logger.info(f"✅ 发现有效快照,可以使用离线模式")
else:
logger.warning(f"⚠️ 未找到本地模型目录: {expected_model_dir}")
logger.warning(f"⚠️ 未找到本地模型目录")
logger.warning(f" 预期位置: {expected_model_dir}")
except Exception as e:
logger.error(f"❌ 检查模型目录失败: {str(e)}")
import traceback
logger.error(f" 堆栈: {traceback.format_exc()}")
else:
logger.warning(f"⚠️ 模型目录不存在: {os.path.abspath(model_cache_dir)}")
logger.warning(f"⚠️ 模型目录不存在: {abs_cache_dir}")
try:
logger.info("🔄 尝试加载主模型: sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
# 优先使用本地缓存的模型
# cache_folder会让模型优先从本地加载,只有不存在时才联网下载
# 注意:不要设置local_files_only=True,这会阻止fallback到联网下载
self.embedding_model = SentenceTransformer(
'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
cache_folder=model_cache_dir,
device='cpu', # 明确指定使用CPU
trust_remote_code=False # 安全起见
# 使用绝对路径检查本地模型
abs_cache_dir = os.path.abspath(model_cache_dir)
local_model_path = os.path.join(
abs_cache_dir,
'models--sentence-transformers--paraphrase-multilingual-MiniLM-L12-v2'
)
logger.info("✅ Embedding模型加载成功 (paraphrase-multilingual-MiniLM-L12-v2)")
logger.info(f"🔍 检查本地模型路径: {local_model_path}")
logger.info(f"🔍 路径存在检查: {os.path.exists(local_model_path)}")
# 检查快照目录是否存在且有内容
snapshots_dir = os.path.join(local_model_path, 'snapshots')
has_valid_model = False
if os.path.exists(snapshots_dir):
try:
snapshots = os.listdir(snapshots_dir)
if snapshots:
logger.info(f"✅ 发现本地模型快照: {snapshots}")
has_valid_model = True
except Exception as e:
logger.warning(f"⚠️ 检查快照失败: {e}")
# 优先尝试从本地路径加载
if has_valid_model:
logger.info(f"✅ 检测到完整本地模型,使用离线模式加载")
try:
self.embedding_model = SentenceTransformer(
'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
cache_folder=abs_cache_dir,
device='cpu',
trust_remote_code=True,
local_files_only=True # 强制使用本地文件
)
logger.info("✅ Embedding模型加载成功 (离线模式)")
except Exception as local_err:
logger.warning(f"⚠️ 离线模式加载失败: {str(local_err)}")
logger.info("🔄 尝试在线模式...")
raise local_err
else:
logger.info("📥 本地模型不完整或不存在,将联网下载...")
logger.info(f" 下载后将保存到: {abs_cache_dir}")
self.embedding_model = SentenceTransformer(
'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
cache_folder=abs_cache_dir,
device='cpu',
trust_remote_code=True,
local_files_only=False # 允许联网下载
)
logger.info("✅ Embedding模型加载成功 (在线下载)")
except Exception as e:
logger.warning(f"⚠️ 无法加载多语言模型: {str(e)}")
logger.error(f"❌ 详细错误: {repr(e)}")
@@ -659,6 +761,44 @@ class MemoryService:
logger.error(f"❌ 删除章节记忆失败: {str(e)}")
return False
async def delete_project_memories(
self,
user_id: str,
project_id: str
) -> bool:
"""
删除指定项目的所有记忆(包括向量数据库)
Args:
user_id: 用户ID
project_id: 项目ID
Returns:
是否删除成功
"""
try:
# 生成collection名称
user_hash = hashlib.sha256(user_id.encode()).hexdigest()[:8]
project_hash = hashlib.sha256(project_id.encode()).hexdigest()[:8]
collection_name = f"u_{user_hash}_p_{project_hash}"
# 删除整个collection(这会清理所有向量数据)
try:
self.client.delete_collection(name=collection_name)
logger.info(f"🗑️ 已删除项目{project_id[:8]}的向量数据库collection: {collection_name}")
return True
except Exception as e:
# 如果collection不存在,也算成功
if "does not exist" in str(e).lower():
logger.info(f"️ 项目{project_id[:8]}的collection不存在,无需删除")
return True
else:
raise
except Exception as e:
logger.error(f"❌ 删除项目记忆失败: {str(e)}")
return False
async def update_memory(
self,
user_id: str,
+7 -4
View File
@@ -20,11 +20,14 @@ class LinuxDOOAuthService:
self.client_secret = settings.LINUXDO_CLIENT_SECRET
self.redirect_uri = settings.LINUXDO_REDIRECT_URI
# 验证redirect_uri配置
# 如果未配置,使用默认值(本地开发)
if not self.redirect_uri:
raise ValueError(
"LINUXDO_REDIRECT_URI 未配置!\n"
"请在 .env 文件中设置正确的回调地址:\n"
self.redirect_uri = "http://localhost:8000/api/auth/callback"
import logging
logger = logging.getLogger(__name__)
logger.warning(
"⚠️ LINUXDO_REDIRECT_URI 未配置,使用默认值: http://localhost:8000/api/auth/callback\n"
"如需使用 OAuth 登录,请在 .env 文件中配置:\n"
"本地开发: LINUXDO_REDIRECT_URI=http://localhost:8000/api/auth/callback\n"
"Docker部署: LINUXDO_REDIRECT_URI=https://your-domain.com/api/auth/callback"
)
+152 -221
View File
@@ -1,9 +1,12 @@
"""剧情分析服务 - 自动分析章节的钩子、伏笔、冲突等元素"""
from typing import Dict, Any, List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.ai_service import AIService
from app.services.prompt_service import prompt_service, PromptService
from app.logger import get_logger
import json
import re
import asyncio
logger = get_logger(__name__)
@@ -11,169 +14,6 @@ logger = get_logger(__name__)
class PlotAnalyzer:
"""剧情分析器 - 使用AI分析章节内容"""
# AI分析提示词模板
ANALYSIS_PROMPT = """你是一位专业的小说编辑和剧情分析师。请深度分析以下章节内容:
**章节信息:**
- 章节: {chapter_number}
- 标题: {title}
- 字数: {word_count}
**章节内容:**
{content}
---
**分析任务:**
请从专业编辑的角度,全面分析这一章节:
### 1. 剧情钩子 (Hooks) - 吸引读者的元素
识别能够吸引读者继续阅读的关键元素:
- **悬念钩子**: 未解之谜疑问谜团
- **情感钩子**: 引发共鸣的情感点触动心弦的时刻
- **冲突钩子**: 矛盾对抗紧张局势
- **认知钩子**: 颠覆认知的信息惊人真相
每个钩子需要:
- 类型分类
- 具体内容描述
- 强度评分(1-10)
- 出现位置(开头/中段/结尾)
- **关键词**: 必填从章节原文中逐字复制一段关键文本(8-25)必须是原文中真实存在的连续文字用于在文本中精确定位不要概括或改写必须原样复制
### 2. 伏笔分析 (Foreshadowing)
- **埋下的新伏笔**: 描述内容预期作用隐藏程度(1-10)
- **回收的旧伏笔**: 呼应哪一章回收效果评分
- **伏笔质量**: 巧妙性和合理性评估
- **关键词**: 必填从章节原文中逐字复制一段关键文本(8-25)必须是原文中真实存在的连续文字用于在文本中精确定位不要概括或改写必须原样复制
### 3. 冲突分析 (Conflict)
- 冲突类型: 人与人/人与己/人与环境/人与社会
- 冲突各方及其立场
- 冲突强度评分(1-10)
- 冲突解决进度(0-100%)
### 4. 情感曲线 (Emotional Arc)
- 主导情绪: 紧张/温馨/悲伤/激昂/平静等
- 情感强度(1-10)
- 情绪变化轨迹描述
### 5. 角色状态追踪 (Character Development)
对每个出场角色分析:
- 心理状态变化()
- 关系变化
- 关键行动和决策
- 成长或退步
### 6. 关键情节点 (Plot Points)
列出3-5个核心情节点:
- 情节内容
- 类型(revelation/conflict/resolution/transition)
- 重要性(0.0-1.0)
- 对故事的影响
- **关键词**: 必填从章节原文中逐字复制一段关键文本(8-25)必须是原文中真实存在的连续文字用于在文本中精确定位不要概括或改写必须原样复制
### 7. 场景与节奏
- 主要场景
- 叙事节奏(//)
- 对话与描写的比例
### 8. 质量评分
- 节奏把控: 1-10
- 吸引力: 1-10
- 连贯性: 1-10
- 整体质量: 1-10
### 9. 改进建议
提供3-5条具体的改进建议
---
**输出格式(纯JSON,不要markdown标记):**
{{
"hooks": [
{{
"type": "悬念",
"content": "具体描述",
"strength": 8,
"position": "中段",
"keyword": "必须从原文逐字复制的文本片段"
}}
],
"foreshadows": [
{{
"content": "伏笔内容",
"type": "planted",
"strength": 7,
"subtlety": 8,
"reference_chapter": null,
"keyword": "必须从原文逐字复制的文本片段"
}}
],
"conflict": {{
"types": ["人与人", "人与己"],
"parties": ["主角-复仇", "反派-维护现状"],
"level": 8,
"description": "冲突描述",
"resolution_progress": 0.3
}},
"emotional_arc": {{
"primary_emotion": "紧张",
"intensity": 8,
"curve": "平静→紧张→高潮→释放",
"secondary_emotions": ["期待", "焦虑"]
}},
"character_states": [
{{
"character_name": "张三",
"state_before": "犹豫",
"state_after": "坚定",
"psychological_change": "心理变化描述",
"key_event": "触发事件",
"relationship_changes": {{"李四": "关系改善"}}
}}
],
"plot_points": [
{{
"content": "情节点描述",
"type": "revelation",
"importance": 0.9,
"impact": "推动故事发展",
"keyword": "必须从原文逐字复制的文本片段"
}}
],
"scenes": [
{{
"location": "地点",
"atmosphere": "氛围",
"duration": "时长估计"
}}
],
"pacing": "varied",
"dialogue_ratio": 0.4,
"description_ratio": 0.3,
"scores": {{
"pacing": 8,
"engagement": 9,
"coherence": 8,
"overall": 8.5
}},
"plot_stage": "发展",
"suggestions": [
"具体建议1",
"具体建议2"
]
}}
**重要提示:**
1. 每个钩子伏笔情节点的keyword字段是必填的不能为空
2. keyword必须是从章节原文中逐字复制的文本长度8-25
3. keyword用于在前端标注文本位置所以必须能在原文中精确找到
4. 不要使用概括性语句或改写后的文字作为keyword
只返回JSON,不要其他说明"""
def __init__(self, ai_service: AIService):
"""
初始化剧情分析器
@@ -189,63 +29,135 @@ class PlotAnalyzer:
chapter_number: int,
title: str,
content: str,
word_count: int
word_count: int,
user_id: str = None,
db: AsyncSession = None,
max_retries: int = 3
) -> Optional[Dict[str, Any]]:
"""
分析单章内容
分析单章内容带重试机制
Args:
chapter_number: 章节号
title: 章节标题
content: 章节内容
word_count: 字数
user_id: 用户ID用于获取自定义提示词
db: 数据库会话用于查询自定义提示词
max_retries: 最大重试次数默认3次
Returns:
分析结果字典,失败返回None
"""
logger.info(f"🔍 开始分析第{chapter_number}章: {title}")
# 如果内容过长,截取前8000字(避免超token)
analysis_content = content[:8000] if len(content) > 8000 else content
# 获取自定义提示词模板
try:
logger.info(f"🔍 开始分析第{chapter_number}章: {title}")
# 如果内容过长,截取前8000字(避免超token)
analysis_content = content[:8000] if len(content) > 8000 else content
# 构建提示词
prompt = self.ANALYSIS_PROMPT.format(
chapter_number=chapter_number,
title=title,
word_count=word_count,
content=analysis_content
)
# 调用AI进行分析
# 注意:不指定max_tokens,使用用户在设置中配置的值
logger.info(f" 调用AI分析(内容长度: {len(analysis_content)}字)...")
response = await self.ai_service.generate_text(
prompt=prompt,
temperature=0.3 # 降低温度以获得更稳定的JSON输出
)
# 解析JSON结果
analysis_result = self._parse_analysis_response(response)
if analysis_result:
logger.info(f"✅ 第{chapter_number}章分析完成")
logger.info(f" - 钩子: {len(analysis_result.get('hooks', []))}")
logger.info(f" - 伏笔: {len(analysis_result.get('foreshadows', []))}")
logger.info(f" - 情节点: {len(analysis_result.get('plot_points', []))}")
logger.info(f" - 整体评分: {analysis_result.get('scores', {}).get('overall', 'N/A')}")
return analysis_result
if user_id and db:
template = await PromptService.get_template("PLOT_ANALYSIS", user_id, db)
else:
logger.error(f"❌ 第{chapter_number}章分析失败: JSON解析错误")
return None
# 降级到系统默认模板
template = PromptService.PLOT_ANALYSIS
except Exception as e:
logger.error(f"❌ 章节分析异常: {str(e)}")
return None
logger.warning(f"⚠️ 获取提示词模板失败,使用默认模板: {str(e)}")
template = PromptService.PLOT_ANALYSIS
# 格式化提示词
prompt = PromptService.format_prompt(
template,
chapter_number=chapter_number,
title=title,
word_count=word_count,
content=analysis_content
)
last_error = None
for attempt in range(1, max_retries + 1):
try:
# 调用AI进行分析
logger.info(f" 📡 调用AI分析(内容长度: {len(analysis_content)}字, 尝试 {attempt}/{max_retries})...")
accumulated_text = ""
try:
async for chunk in self.ai_service.generate_text_stream(
prompt=prompt,
temperature=0.3 # 降低温度以获得更稳定的JSON输出
):
accumulated_text += chunk
except GeneratorExit:
# 流式响应被中断
logger.warning(f"⚠️ 流式响应被中断(GeneratorExit),已累积 {len(accumulated_text)} 字符")
# 如果已经累积了足够内容,继续尝试解析
if len(accumulated_text) < 100:
raise Exception("流式响应中断,内容不足")
except Exception as stream_error:
logger.error(f"❌ 流式生成出错: {str(stream_error)}")
raise
# 检查响应是否为空
if not accumulated_text or len(accumulated_text.strip()) < 10:
logger.warning(f"⚠️ AI响应为空或过短(长度: {len(accumulated_text)}), 尝试 {attempt}/{max_retries}")
last_error = "AI响应为空或过短"
if attempt < max_retries:
wait_time = min(2 ** attempt, 10)
logger.info(f" ⏳ 等待 {wait_time} 秒后重试...")
await asyncio.sleep(wait_time)
continue
else:
logger.error(f"❌ 第{chapter_number}章分析失败: AI响应为空,已达最大重试次数")
return None
# 提取内容
response_text = accumulated_text
logger.debug(f" 收到AI响应,长度: {len(response_text)} 字符")
# 解析JSON结果
analysis_result = self._parse_analysis_response(response_text)
if analysis_result:
logger.info(f"✅ 第{chapter_number}章分析完成 (尝试 {attempt}/{max_retries})")
logger.info(f" - 钩子: {len(analysis_result.get('hooks', []))}")
logger.info(f" - 伏笔: {len(analysis_result.get('foreshadows', []))}")
logger.info(f" - 情节点: {len(analysis_result.get('plot_points', []))}")
logger.info(f" - 整体评分: {analysis_result.get('scores', {}).get('overall', 'N/A')}")
return analysis_result
else:
# JSON解析失败,重试
logger.warning(f"⚠️ JSON解析失败, 尝试 {attempt}/{max_retries}")
last_error = "JSON解析失败"
if attempt < max_retries:
wait_time = min(2 ** attempt, 10)
logger.info(f" ⏳ 等待 {wait_time} 秒后重试...")
await asyncio.sleep(wait_time)
continue
else:
logger.error(f"❌ 第{chapter_number}章分析失败: JSON解析错误,已达最大重试次数")
return None
except Exception as e:
last_error = str(e)
logger.error(f"❌ 章节分析异常(尝试 {attempt}/{max_retries}): {last_error}")
if attempt < max_retries:
wait_time = min(2 ** attempt, 10)
logger.info(f" ⏳ 等待 {wait_time} 秒后重试...")
await asyncio.sleep(wait_time)
continue
else:
logger.error(f"❌ 第{chapter_number}章分析失败: {last_error},已达最大重试次数")
return None
# 不应该到达这里,但作为安全措施
logger.error(f"❌ 第{chapter_number}章分析失败: {last_error}")
return None
def _parse_analysis_response(self, response: str) -> Optional[Dict[str, Any]]:
"""
解析AI返回的分析结果
解析AI返回的分析结果使用统一的JSON清洗方法
Args:
response: AI返回的文本
@@ -254,13 +166,8 @@ class PlotAnalyzer:
解析后的字典,失败返回None
"""
try:
# 清理响应文本
cleaned = response.strip()
# 移除可能的markdown标记
cleaned = re.sub(r'^```json\s*', '', cleaned)
cleaned = re.sub(r'^```\s*', '', cleaned)
cleaned = re.sub(r'\s*```$', '', cleaned)
# 使用统一的JSON清洗方法
cleaned = self.ai_service._clean_json_response(response)
# 尝试解析JSON
result = json.loads(cleaned)
@@ -272,22 +179,12 @@ class PlotAnalyzer:
logger.warning(f"⚠️ 分析结果缺少字段: {field}")
result[field] = [] if field != 'scores' else {}
logger.info("✅ 成功解析分析结果")
return result
except json.JSONDecodeError as e:
logger.error(f"❌ JSON解析失败: {str(e)}")
logger.error(f" 原始响应(前500字): {response[:500]}")
# 尝试提取JSON部分
json_match = re.search(r'\{[\s\S]*\}', response)
if json_match:
try:
result = json.loads(json_match.group())
logger.info("✅ 通过正则提取成功解析JSON")
return result
except:
pass
return None
except Exception as e:
logger.error(f"❌ 解析异常: {str(e)}")
@@ -298,7 +195,8 @@ class PlotAnalyzer:
analysis: Dict[str, Any],
chapter_id: str,
chapter_number: int,
chapter_content: str = ""
chapter_content: str = "",
chapter_title: str = ""
) -> List[Dict[str, Any]]:
"""
从分析结果中提取记忆片段
@@ -308,6 +206,7 @@ class PlotAnalyzer:
chapter_id: 章节ID
chapter_number: 章节号
chapter_content: 章节完整内容(用于计算位置)
chapter_title: 章节标题
Returns:
记忆片段列表
@@ -315,6 +214,38 @@ class PlotAnalyzer:
memories = []
try:
# 【新增】0. 提取章节摘要作为记忆(用于语义检索相关章节)
chapter_summary = ""
# 尝试从分析结果获取摘要
if analysis.get('summary'):
chapter_summary = analysis.get('summary')
# 或者从情节点组合生成摘要
elif analysis.get('plot_points'):
plot_summaries = [p.get('content', '') for p in analysis.get('plot_points', [])[:3]]
chapter_summary = "".join(plot_summaries)
# 或者使用内容前300字
elif chapter_content:
chapter_summary = chapter_content[:300] + ("..." if len(chapter_content) > 300 else "")
# 如果有摘要,添加到记忆中
if chapter_summary:
memories.append({
'type': 'chapter_summary',
'content': chapter_summary,
'title': f"{chapter_number}章《{chapter_title}》摘要",
'metadata': {
'chapter_id': chapter_id,
'chapter_number': chapter_number,
'importance_score': 0.6, # 中等重要性
'tags': ['摘要', '章节概览', chapter_title],
'is_foreshadow': 0,
'text_position': 0,
'text_length': len(chapter_summary)
}
})
logger.info(f" ✅ 添加章节摘要记忆: {len(chapter_summary)}")
# 1. 提取钩子作为记忆
for i, hook in enumerate(analysis.get('hooks', [])):
if hook.get('strength', 0) >= 6: # 只保存强度>=6的钩子
@@ -0,0 +1,645 @@
"""大纲剧情展开服务 - 将大纲节点展开为多个章节"""
from typing import List, Dict, Any, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
import json
from app.models.outline import Outline
from app.models.project import Project
from app.models.character import Character
from app.models.chapter import Chapter
from app.services.ai_service import AIService
from app.services.prompt_service import prompt_service, PromptService
from app.logger import get_logger
logger = get_logger(__name__)
class PlotExpansionService:
"""大纲剧情展开服务"""
def __init__(self, ai_service: AIService):
self.ai_service = ai_service
async def analyze_outline_for_chapters(
self,
outline: Outline,
project: Project,
db: AsyncSession,
target_chapter_count: int = 3,
expansion_strategy: str = "balanced",
enable_scene_analysis: bool = True,
provider: Optional[str] = None,
model: Optional[str] = None,
batch_size: int = 5,
progress_callback: Optional[callable] = None
) -> List[Dict[str, Any]]:
"""
分析单个大纲,生成多章节规划支持分批生成
Args:
outline: 大纲对象
project: 项目对象
db: 数据库会话
target_chapter_count: 目标生成章节数
expansion_strategy: 展开策略(balanced/climax/detail)
enable_scene_analysis: 是否启用场景级分析
provider: AI提供商
model: AI模型
batch_size: 每批生成的章节数默认5章
progress_callback: 进度回调函数(可选)
Returns:
章节规划列表
"""
logger.info(f"开始分析大纲 {outline.id},目标生成 {target_chapter_count}")
# 如果章节数较少,直接生成
if target_chapter_count <= batch_size:
return await self._generate_chapters_single_batch(
outline=outline,
project=project,
db=db,
target_chapter_count=target_chapter_count,
expansion_strategy=expansion_strategy,
enable_scene_analysis=enable_scene_analysis,
provider=provider,
model=model
)
# 章节数较多,分批生成
logger.info(f"章节数({target_chapter_count})超过批次大小({batch_size}),启用分批生成")
return await self._generate_chapters_in_batches(
outline=outline,
project=project,
db=db,
target_chapter_count=target_chapter_count,
expansion_strategy=expansion_strategy,
enable_scene_analysis=enable_scene_analysis,
provider=provider,
model=model,
batch_size=batch_size,
progress_callback=progress_callback
)
async def _generate_chapters_single_batch(
self,
outline: Outline,
project: Project,
db: AsyncSession,
target_chapter_count: int,
expansion_strategy: str,
enable_scene_analysis: bool,
provider: Optional[str],
model: Optional[str]
) -> List[Dict[str, Any]]:
"""单批次生成章节规划"""
# 获取角色信息
characters_result = await db.execute(
select(Character).where(Character.project_id == project.id)
)
characters = characters_result.scalars().all()
characters_info = "\n".join([
f"- {char.name} ({'组织' if char.is_organization else '角色'}, {char.role_type}): "
f"{char.personality[:100] if char.personality else '暂无描述'}"
for char in characters
])
# 获取大纲上下文(前后大纲)
context_info = await self._get_outline_context(outline, project.id, db)
# 获取自定义提示词模板
template = await PromptService.get_template("OUTLINE_EXPAND_SINGLE", project.user_id, db)
# 格式化提示词
prompt = PromptService.format_prompt(
template,
project_title=project.title,
project_genre=project.genre or '通用',
project_theme=project.theme or '未设定',
project_narrative_perspective=project.narrative_perspective or '第三人称',
project_world_time_period=project.world_time_period or '未设定',
project_world_location=project.world_location or '未设定',
project_world_atmosphere=project.world_atmosphere or '未设定',
characters_info=characters_info or '暂无角色',
outline_order_index=outline.order_index,
outline_title=outline.title,
outline_content=outline.content,
context_info=context_info,
strategy_instruction=expansion_strategy,
target_chapter_count=target_chapter_count,
scene_instruction="", # 暂时为空
scene_field="" # 暂时为空
)
# 调用AI生成章节规划
logger.info(f"调用AI生成章节规划...")
accumulated_text = ""
async for chunk in self.ai_service.generate_text_stream(
prompt=prompt,
provider=provider,
model=model
):
accumulated_text += chunk
# 提取内容
ai_content = accumulated_text
# 解析AI响应
chapter_plans = self._parse_expansion_response(ai_content, outline.id)
logger.info(f"成功生成 {len(chapter_plans)} 个章节规划")
return chapter_plans
async def _generate_chapters_in_batches(
self,
outline: Outline,
project: Project,
db: AsyncSession,
target_chapter_count: int,
expansion_strategy: str,
enable_scene_analysis: bool,
provider: Optional[str],
model: Optional[str],
batch_size: int,
progress_callback: Optional[callable]
) -> List[Dict[str, Any]]:
"""分批生成章节规划"""
# 计算批次数
total_batches = (target_chapter_count + batch_size - 1) // batch_size
logger.info(f"分批生成计划: 总共{target_chapter_count}章,分{total_batches}批,每批{batch_size}")
# 获取角色信息(所有批次共用)
characters_result = await db.execute(
select(Character).where(Character.project_id == project.id)
)
characters = characters_result.scalars().all()
characters_info = "\n".join([
f"- {char.name} ({'组织' if char.is_organization else '角色'}, {char.role_type}): "
f"{char.personality[:100] if char.personality else '暂无描述'}"
for char in characters
])
# 获取大纲上下文
context_info = await self._get_outline_context(outline, project.id, db)
all_chapter_plans = []
for batch_num in range(total_batches):
# 计算当前批次的章节数
remaining_chapters = target_chapter_count - len(all_chapter_plans)
current_batch_size = min(batch_size, remaining_chapters)
current_start_index = len(all_chapter_plans) + 1
logger.info(f"开始生成第{batch_num + 1}/{total_batches}批,章节范围: {current_start_index}-{current_start_index + current_batch_size - 1}")
# 回调通知进度
if progress_callback:
await progress_callback(batch_num + 1, total_batches, current_start_index, current_batch_size)
# 构建当前批次的提示词(包含已生成章节的上下文)
previous_context = ""
if all_chapter_plans:
previous_summaries = []
for ch in all_chapter_plans[-3:]: # 只显示最近3章
previous_summaries.append(
f"{ch['sub_index']}节《{ch['title']}》: {ch['plot_summary'][:100]}..."
)
previous_context = f"""
已生成章节概要接续生成注意衔接
{chr(10).join(previous_summaries)}
当前是第{current_start_index}-{current_start_index + current_batch_size - 1}{target_chapter_count}节中的一部分
"""
# 获取自定义提示词模板
template = await PromptService.get_template("OUTLINE_EXPAND_MULTI", project.user_id, db)
# 格式化提示词
prompt = PromptService.format_prompt(
template,
project_title=project.title,
project_genre=project.genre or '通用',
project_theme=project.theme or '未设定',
project_narrative_perspective=project.narrative_perspective or '第三人称',
project_world_time_period=project.world_time_period or '未设定',
project_world_location=project.world_location or '未设定',
project_world_atmosphere=project.world_atmosphere or '未设定',
characters_info=characters_info or '暂无角色',
outline_order_index=outline.order_index,
outline_title=outline.title,
outline_content=outline.content,
context_info=context_info,
previous_context=previous_context,
strategy_instruction=expansion_strategy,
start_index=current_start_index,
end_index=current_start_index + current_batch_size - 1,
target_chapter_count=current_batch_size,
scene_instruction="", # 暂时为空
scene_field="" # 暂时为空
)
# 调用AI生成当前批次
logger.info(f"调用AI生成第{batch_num + 1}批...")
accumulated_text = ""
async for chunk in self.ai_service.generate_text_stream(
prompt=prompt,
provider=provider,
model=model
):
accumulated_text += chunk
# 提取内容
ai_content = accumulated_text
# 解析AI响应
batch_plans = self._parse_expansion_response(ai_content, outline.id)
# 调整sub_index以保持连续性
for i, plan in enumerate(batch_plans):
plan["sub_index"] = current_start_index + i
all_chapter_plans.extend(batch_plans)
logger.info(f"{batch_num + 1}批生成完成,本批生成{len(batch_plans)}章,累计{len(all_chapter_plans)}")
logger.info(f"分批生成完成,共生成 {len(all_chapter_plans)} 个章节规划")
return all_chapter_plans
async def batch_expand_outlines(
self,
project_id: str,
db: AsyncSession,
ai_service: AIService,
target_chapters_per_outline: int = 3,
expansion_strategy: str = "balanced",
provider: Optional[str] = None,
model: Optional[str] = None
) -> Dict[str, Any]:
"""
批量展开所有大纲为章节
Returns:
{
"total_outlines": 总大纲数,
"total_chapters_planned": 规划的总章节数,
"expansions": [每个大纲的展开结果]
}
"""
logger.info(f"开始批量展开项目 {project_id} 的所有大纲")
# 获取项目
project_result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = project_result.scalar_one_or_none()
if not project:
raise ValueError(f"项目 {project_id} 不存在")
# 获取所有大纲
outlines_result = await db.execute(
select(Outline)
.where(Outline.project_id == project_id)
.order_by(Outline.order_index)
)
outlines = outlines_result.scalars().all()
if not outlines:
logger.warning(f"项目 {project_id} 没有大纲")
return {
"total_outlines": 0,
"total_chapters_planned": 0,
"expansions": []
}
# 逐个展开大纲
expansions = []
total_chapters = 0
for outline in outlines:
try:
chapter_plans = await self.analyze_outline_for_chapters(
outline=outline,
project=project,
db=db,
target_chapter_count=target_chapters_per_outline,
expansion_strategy=expansion_strategy,
provider=provider,
model=model
)
expansions.append({
"outline_id": outline.id,
"outline_title": outline.title,
"chapter_plans": chapter_plans,
"chapter_count": len(chapter_plans)
})
total_chapters += len(chapter_plans)
logger.info(f"大纲 {outline.title} 展开为 {len(chapter_plans)}")
except Exception as e:
logger.error(f"展开大纲 {outline.id} 失败: {str(e)}")
expansions.append({
"outline_id": outline.id,
"outline_title": outline.title,
"error": str(e),
"chapter_count": 0
})
result = {
"total_outlines": len(outlines),
"total_chapters_planned": total_chapters,
"expansions": expansions
}
logger.info(f"批量展开完成: {len(outlines)} 个大纲 → {total_chapters} 个章节规划")
return result
async def create_chapters_from_plans(
self,
outline_id: str,
chapter_plans: List[Dict[str, Any]],
project_id: str,
db: AsyncSession,
start_chapter_number: int = None
) -> List[Chapter]:
"""
根据章节规划创建实际的章节记录
Args:
outline_id: 大纲ID
chapter_plans: 章节规划列表
project_id: 项目ID
db: 数据库会话
start_chapter_number: 起始章节号如果为None则自动计算
Returns:
创建的章节列表
"""
logger.info(f"根据规划创建 {len(chapter_plans)} 个章节记录")
# 如果没有指定起始章节号,根据大纲顺序自动计算
if start_chapter_number is None:
# 1. 获取当前大纲信息
outline_result = await db.execute(
select(Outline).where(Outline.id == outline_id)
)
current_outline = outline_result.scalar_one_or_none()
if not current_outline:
raise ValueError(f"大纲 {outline_id} 不存在")
# 2. 查询所有在当前大纲之前的大纲(按order_index排序)
prev_outlines_result = await db.execute(
select(Outline)
.where(
Outline.project_id == project_id,
Outline.order_index < current_outline.order_index
)
.order_by(Outline.order_index)
)
prev_outlines = prev_outlines_result.scalars().all()
# 3. 计算前面所有大纲已展开的章节总数
total_prev_chapters = 0
for prev_outline in prev_outlines:
count_result = await db.execute(
select(func.count(Chapter.id))
.where(
Chapter.project_id == project_id,
Chapter.outline_id == prev_outline.id
)
)
total_prev_chapters += count_result.scalar() or 0
# 4. 起始章节号 = 前面所有大纲的章节数 + 1
start_chapter_number = total_prev_chapters + 1
logger.info(f"自动计算起始章节号: {start_chapter_number} (基于大纲order_index={current_outline.order_index}, 前置章节数={total_prev_chapters})")
chapters = []
for idx, plan in enumerate(chapter_plans):
# 保存完整的展开规划数据(JSON格式)
expansion_plan_json = json.dumps({
"key_events": plan.get("key_events", []),
"character_focus": plan.get("character_focus", []),
"emotional_tone": plan.get("emotional_tone", ""),
"narrative_goal": plan.get("narrative_goal", ""),
"conflict_type": plan.get("conflict_type", ""),
"estimated_words": plan.get("estimated_words", 3000),
"scenes": plan.get("scenes", []) if plan.get("scenes") else None
}, ensure_ascii=False)
chapter = Chapter(
project_id=project_id,
outline_id=outline_id,
chapter_number=start_chapter_number + idx,
sub_index=plan.get("sub_index", idx + 1),
title=plan.get("title", f"{start_chapter_number + idx}"),
summary=plan.get("plot_summary", ""),
expansion_plan=expansion_plan_json,
status="draft"
)
db.add(chapter)
chapters.append(chapter)
await db.commit()
for chapter in chapters:
await db.refresh(chapter)
logger.info(f"成功创建 {len(chapters)} 个章节记录(已保存展开规划数据)")
# 重新排序当前大纲之后的所有章节
await self._renumber_subsequent_chapters(
project_id=project_id,
current_outline_id=outline_id,
db=db
)
return chapters
async def _get_outline_context(
self,
outline: Outline,
project_id: str,
db: AsyncSession
) -> str:
"""获取大纲的上下文(前后大纲)"""
# 获取前一个大纲
prev_result = await db.execute(
select(Outline)
.where(
Outline.project_id == project_id,
Outline.order_index < outline.order_index
)
.order_by(Outline.order_index.desc())
.limit(1)
)
prev_outline = prev_result.scalar_one_or_none()
# 获取后一个大纲
next_result = await db.execute(
select(Outline)
.where(
Outline.project_id == project_id,
Outline.order_index > outline.order_index
)
.order_by(Outline.order_index)
.limit(1)
)
next_outline = next_result.scalar_one_or_none()
context = ""
if prev_outline:
context += f"【前一节】{prev_outline.title}: {prev_outline.content[:200]}...\n\n"
if next_outline:
context += f"【后一节】{next_outline.title}: {next_outline.content[:200]}...\n"
return context if context else "(无前后文)"
def _parse_expansion_response(
self,
ai_response: str,
outline_id: str
) -> List[Dict[str, Any]]:
"""解析AI的展开响应(使用统一的JSON清洗方法)"""
try:
# 使用统一的JSON清洗方法
cleaned_text = self.ai_service._clean_json_response(ai_response)
# 解析JSON
chapter_plans = json.loads(cleaned_text)
# 确保是列表
if not isinstance(chapter_plans, list):
chapter_plans = [chapter_plans]
# 为每个章节规划添加outline_id
for plan in chapter_plans:
plan["outline_id"] = outline_id
logger.info(f"✅ 成功解析 {len(chapter_plans)} 个章节规划")
return chapter_plans
except json.JSONDecodeError as e:
logger.error(f"❌ 解析AI响应失败: {e}, 响应内容: {ai_response[:500]}")
# 返回一个基础规划
return [{
"outline_id": outline_id,
"sub_index": 1,
"title": "AI解析失败的默认章节",
"plot_summary": ai_response[:500],
"key_events": ["解析失败"],
"character_focus": [],
"emotional_tone": "未知",
"narrative_goal": "需要重新生成",
"conflict_type": "未知",
"estimated_words": 3000
}]
except Exception as e:
logger.error(f"❌ 解析异常: {str(e)}")
return [{
"outline_id": outline_id,
"sub_index": 1,
"title": "解析异常的默认章节",
"plot_summary": "系统错误",
"key_events": [],
"character_focus": [],
"emotional_tone": "未知",
"narrative_goal": "需要重新生成",
"conflict_type": "未知",
"estimated_words": 3000
}]
async def _renumber_subsequent_chapters(
self,
project_id: str,
current_outline_id: str,
db: AsyncSession
):
"""
重新计算当前大纲之后所有大纲的章节序号
Args:
project_id: 项目ID
current_outline_id: 当前大纲ID
db: 数据库会话
"""
logger.info(f"开始重新排序大纲 {current_outline_id} 之后的所有章节")
# 1. 获取当前大纲信息
current_outline_result = await db.execute(
select(Outline).where(Outline.id == current_outline_id)
)
current_outline = current_outline_result.scalar_one_or_none()
if not current_outline:
logger.warning(f"大纲 {current_outline_id} 不存在,跳过重新排序")
return
# 2. 获取当前大纲及之后的所有大纲(按order_index排序)
subsequent_outlines_result = await db.execute(
select(Outline)
.where(
Outline.project_id == project_id,
Outline.order_index >= current_outline.order_index
)
.order_by(Outline.order_index)
)
subsequent_outlines = subsequent_outlines_result.scalars().all()
# 3. 计算每个大纲的起始章节号
current_chapter_number = 1
# 先计算前面大纲的章节总数
prev_outlines_result = await db.execute(
select(Outline)
.where(
Outline.project_id == project_id,
Outline.order_index < current_outline.order_index
)
.order_by(Outline.order_index)
)
prev_outlines = prev_outlines_result.scalars().all()
for prev_outline in prev_outlines:
count_result = await db.execute(
select(func.count(Chapter.id))
.where(
Chapter.project_id == project_id,
Chapter.outline_id == prev_outline.id
)
)
current_chapter_number += count_result.scalar() or 0
# 4. 逐个大纲更新章节序号
updated_count = 0
for outline in subsequent_outlines:
# 获取该大纲的所有章节(按sub_index排序)
chapters_result = await db.execute(
select(Chapter)
.where(
Chapter.project_id == project_id,
Chapter.outline_id == outline.id
)
.order_by(Chapter.sub_index)
)
chapters = chapters_result.scalars().all()
# 更新每个章节的chapter_number
for chapter in chapters:
if chapter.chapter_number != current_chapter_number:
logger.debug(f"更新章节 {chapter.id}: {chapter.chapter_number} -> {current_chapter_number}")
chapter.chapter_number = current_chapter_number
updated_count += 1
current_chapter_number += 1
# 5. 提交更新
await db.commit()
logger.info(f"重新排序完成,共更新 {updated_count} 个章节的序号")
# 工厂函数
def create_plot_expansion_service(ai_service: AIService) -> PlotExpansionService:
"""创建剧情展开服务实例"""
return PlotExpansionService(ai_service)

Some files were not shown because too many files have changed in this diff Show More