Merge branch 'dev' of https://github.com/xiamuceer-j/MuMuAINovel into dev
This commit is contained in:
@@ -39,6 +39,7 @@ Thumbs.db
|
||||
# 数据库文件(不包含在镜像中)
|
||||
data/*.db
|
||||
backend/data/*.db
|
||||
postgres_data/
|
||||
|
||||
# ChromaDB数据库(不包含在镜像中,会在运行时生成)
|
||||
backend/data/chroma_db/
|
||||
|
||||
+1
-1
@@ -1 +1 @@
|
||||
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||
# LFS tracking removed - models downloaded from HuggingFace at runtime
|
||||
|
||||
+16
-12
@@ -32,9 +32,11 @@ WORKDIR /app
|
||||
RUN sed -i 's/deb.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list.d/debian.sources \
|
||||
&& sed -i 's/security.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list.d/debian.sources
|
||||
|
||||
# 安装系统依赖
|
||||
# 安装系统依赖(添加数据库工具)
|
||||
RUN apt-get update && apt-get install -y \
|
||||
gcc \
|
||||
postgresql-client \
|
||||
netcat-traditional \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# 复制后端依赖文件
|
||||
@@ -46,21 +48,23 @@ RUN pip install --no-cache-dir torch==2.7.0 --index-url https://download.pytorch
|
||||
# 再安装其他Python依赖(使用阿里云镜像加速)
|
||||
RUN pip install --no-cache-dir -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/
|
||||
|
||||
# 复制后端代码
|
||||
# 复制后端代码(包含embedding模型)
|
||||
COPY backend/ ./
|
||||
|
||||
# 从前端构建阶段复制构建好的静态文件
|
||||
COPY --from=frontend-builder /frontend/dist ./static
|
||||
|
||||
# 复制 Alembic 迁移配置和脚本(PostgreSQL)
|
||||
COPY backend/alembic-postgres.ini ./alembic.ini
|
||||
COPY backend/alembic/postgres ./alembic
|
||||
COPY backend/scripts/entrypoint.sh /app/entrypoint.sh
|
||||
COPY backend/scripts/migrate.py ./scripts/migrate.py
|
||||
|
||||
# 赋予执行权限
|
||||
RUN chmod +x /app/entrypoint.sh
|
||||
|
||||
# 创建必要的目录
|
||||
RUN mkdir -p /app/data /app/logs /app/embedding
|
||||
|
||||
# 复制预下载的Embedding模型到独立目录(避免被docker-compose的data挂载覆盖)
|
||||
# 这样可以避免首次运行时联网下载约420MB的模型文件
|
||||
COPY backend/embedding /app/embedding
|
||||
|
||||
# 复制环境变量示例文件
|
||||
COPY backend/.env.example ./.env.example
|
||||
RUN mkdir -p /app/data /app/logs
|
||||
|
||||
# 暴露端口
|
||||
EXPOSE 8000
|
||||
@@ -80,5 +84,5 @@ ENV SENTENCE_TRANSFORMERS_HOME=/app/embedding
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
|
||||
|
||||
# 启动命令
|
||||
CMD ["python", "-m", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
# 使用 entrypoint 脚本启动(自动执行迁移)
|
||||
ENTRYPOINT ["/app/entrypoint.sh"]
|
||||
@@ -2,16 +2,47 @@
|
||||
|
||||
<div align="center">
|
||||
|
||||

|
||||

|
||||

|
||||

|
||||

|
||||

|
||||

|
||||
|
||||
**一款基于 AI 的智能小说创作助手,帮助你轻松创作精彩故事**
|
||||
**基于 AI 的智能小说创作助手**
|
||||
|
||||
[特性](#-特性) • [快速开始](#-快速开始) • [部署方式](#-部署方式) • [配置说明](#%EF%B8%8F-配置说明) • [项目结构](#-项目结构)
|
||||
[特性](#-特性) • [快速开始](#-快速开始) • [配置说明](#%EF%B8%8F-配置说明) • [项目结构](#-项目结构)
|
||||
|
||||
</div>
|
||||
|
||||
---
|
||||
|
||||
<div align="center">
|
||||
|
||||
## 💖 支持项目
|
||||
|
||||
如果这个项目对你有帮助,欢迎通过以下方式支持开发:
|
||||
|
||||
**[☕ 请我喝杯咖啡](https://mumuverse.space:1588/)**
|
||||
|
||||
### 🎁 赞助专属权益
|
||||
|
||||
| 权益 | 说明 |
|
||||
|------|------|
|
||||
| 📋 **优先需求响应** | 您的功能需求和问题反馈将获得优先处理 |
|
||||
| 🚀 **Windows一键启动** | 获取免安装EXE程序,双击即可使用 |
|
||||
| 💬 **专属技术支持** | 加入赞助者内部群,获得远程协助和配置指导 |
|
||||
|
||||
### ☕ 赞助金额
|
||||
|
||||
| 金额 | 描述 |
|
||||
|------|------|
|
||||
| ¥5 | 🌶️ 一包辣条 |
|
||||
| ¥10 | 🍱 一顿拼好饭 |
|
||||
| ¥20 | 🧋 一杯咖啡 |
|
||||
| ¥50 | 🍖 一次烧烤 |
|
||||
| ¥99 | 🍲 一顿海底捞 |
|
||||
|
||||
您的支持是我持续开发的动力!🙏
|
||||
|
||||
</div>
|
||||
|
||||
@@ -19,116 +50,387 @@
|
||||
|
||||
## ✨ 特性
|
||||
|
||||
- 🤖 **多 AI 模型支持** - 支持 OpenAI、Google Gemini、Anthropic Claude 等主流 AI 模型
|
||||
- 📝 **智能向导** - 通过向导式引导快速创建小说项目,AI 自动生成大纲、角色和世界观
|
||||
- 👥 **角色管理** - 创建和管理小说角色,包括人物关系、组织架构等
|
||||
- 📖 **章节编辑** - 支持章节的创建、编辑、重新生成和润色功能
|
||||
- 🌐 **世界观设定** - 构建完整的故事世界观和背景设定
|
||||
- 🔐 **多种登录方式** - 支持 LinuxDO OAuth 登录和本地账户登录
|
||||
- 🐳 **Docker 部署** - 一键部署,开箱即用
|
||||
- 💾 **数据持久化** - 基于 SQLite 的本地数据存储,支持多用户隔离
|
||||
- 🎨 **现代化 UI** - 基于 Ant Design 的美观界面,响应式设计
|
||||
- 🤖 **多 AI 模型** - 支持 OpenAI、Gemini、Claude 等主流模型
|
||||
- 📝 **智能向导** - AI 自动生成大纲、角色和世界观
|
||||
- 👥 **角色管理** - 人物关系、组织架构可视化管理
|
||||
- 📖 **章节编辑** - 支持创建、编辑、重新生成和润色
|
||||
- 🌐 **世界观设定** - 构建完整的故事背景
|
||||
- 🔐 **多种登录** - LinuxDO OAuth 或本地账户登录
|
||||
- 💾 **PostgreSQL** - 生产级数据库,多用户数据隔离
|
||||
- 🐳 **Docker 部署** - 一键启动,开箱即用
|
||||
|
||||
## 📸 项目预览
|
||||
|
||||
<details>
|
||||
|
||||
<summary>多图预警</summary>
|
||||
|
||||
<div align="center">
|
||||
|
||||
### 登录界面
|
||||

|
||||
|
||||
### 主界面
|
||||

|
||||
|
||||
### 项目管理
|
||||

|
||||
|
||||
### 赞助我 💖
|
||||

|
||||
|
||||
</div>
|
||||
|
||||
</details>
|
||||
|
||||
## 📋 TODO List
|
||||
|
||||
以下是正在规划和开发中的功能:
|
||||
### ✅ 已完成功能
|
||||
|
||||
- [ ] **灵感模式** - 提供创作灵感和点子生成功能
|
||||
- [✔] **自定义写作风格** - 支持自定义AI写作风格和语言风格
|
||||
- [ ] **支持数据导入导出** - 支持项目数据的导入和导出功能
|
||||
- [ ] **添加prompt调整界面** - 提供可视化的prompt模板编辑和调整界面
|
||||
- [✔] **开放章节内容字数限制** - 支持用户在生成章节内容时设置字数 @wyf007
|
||||
- [ ] **设定追溯与矛盾检测** - 对大纲、世界观、角色档案中的设定支持悬停查看注释,显示相关章节来源和佐证原文;自动检测新章节与已有设定的矛盾(吃书),标记为"矛盾"设定并提供解决建议,当新设定解决矛盾后自动更新注释说明 @lulujiang
|
||||
- [ ] **思维链与章节关系图谱** - 为每章建立思维链,总结与上文的逻辑关系、明暗线发展;可选的章节关系满图功能,自动识别和标注伏笔埋设与揭晓、角色出场与呼应等内在联系,帮助提升小说结构的紧密性和连贯性 @lulujiang
|
||||
- [x] **灵感模式** - 创作灵感和点子生成
|
||||
- [x] **自定义写作风格** - 支持自定义 AI 写作风格
|
||||
- [x] **数据导入导出** - 项目数据的导入导出
|
||||
- [x] **Prompt 调整界面** - 可视化编辑 Prompt 模板
|
||||
- [x] **章节字数限制** - 用户可设置生成字数
|
||||
- [x] **思维链与章节关系图谱** - 可视化章节逻辑关系
|
||||
- [x] **根据分析一键重写** - 根据分析建议重新生成
|
||||
- [x] **Linux DO 自动创建账号** - OAuth 登录自动生成账号
|
||||
- [x] **职业等级体系** - 自定义职业和等级系统,支持修仙境界、魔法等级等多种体系
|
||||
- [x] **角色/组织卡片导入导出** - 单独导出角色和组织卡片,支持跨项目数据共享
|
||||
|
||||
> 💡 如果你有其他功能建议,欢迎提交 Issue 或 Pull Request!
|
||||
### 📝 规划中功能
|
||||
|
||||
- [ ] **伏笔管理** - 智能追踪剧情伏笔,提醒未回收线索,可视化伏笔时间线
|
||||
- [ ] **提示词工坊** - 社区驱动的 Prompt 模板分享平台,一键导入优质提示词
|
||||
|
||||
> 💡 欢迎提交 Issue 或 Pull Request!
|
||||
|
||||
## 🚀 快速开始
|
||||
|
||||
### 前置要求
|
||||
|
||||
- **Docker 部署**:Docker 和 Docker Compose
|
||||
- **本地开发**:Python 3.11+ 和 Node.js 18+
|
||||
- **必需**:至少一个 AI 服务的 API Key(OpenAI/Gemini/Anthropic)
|
||||
- Docker 和 Docker Compose
|
||||
- 至少一个 AI 服务的 API Key(OpenAI/Gemini/Claude)
|
||||
|
||||
### 方式一:从源码构建 Docker 镜像
|
||||
### Docker Compose 部署(推荐)
|
||||
|
||||
```bash
|
||||
# 1. 克隆项目
|
||||
git clone https://github.com/xiamuceer-j/MuMuAINovel.git
|
||||
cd MuMuAINovel
|
||||
|
||||
# 2. 配置环境变量
|
||||
# 2. 配置环境变量(必需)
|
||||
cp backend/.env.example .env
|
||||
# 编辑 .env 文件,填入你的 API Keys
|
||||
# 编辑 .env 文件,填入必要配置(API Key、数据库密码等)
|
||||
|
||||
# 3. 启动服务(会自动构建镜像)
|
||||
# 3. 确保文件准备完整
|
||||
# ⚠️ 重要:确保以下文件存在
|
||||
# - .env(配置文件,必需挂载到容器)
|
||||
# - backend/scripts/init_postgres.sql(数据库初始化脚本)
|
||||
|
||||
# 4. 启动服务
|
||||
docker-compose up -d
|
||||
|
||||
# 4. 访问应用
|
||||
# 5. 访问应用
|
||||
# 打开浏览器访问 http://localhost:8000
|
||||
```
|
||||
|
||||
### 方式二:本地开发
|
||||
> **📌 注意事项**
|
||||
>
|
||||
> 1. **`.env` 文件挂载**: `docker-compose.yml` 会自动将 `.env` 挂载到容器,确保文件存在
|
||||
> 2. **数据库初始化**: `init_postgres.sql` 会在首次启动时自动执行,安装必要的PostgreSQL扩展
|
||||
> 3. **自行构建**: 如需从源码构建,请先下载 embedding 模型文件([加群获取](frontend/public/qq.jpg))
|
||||
|
||||
#### 后端设置
|
||||
### 使用 Docker Hub 镜像(推荐新手)
|
||||
|
||||
```bash
|
||||
# 进入后端目录
|
||||
cd backend
|
||||
|
||||
# 创建虚拟环境
|
||||
python -m venv .venv
|
||||
|
||||
# 激活虚拟环境
|
||||
# Windows:
|
||||
.venv\Scripts\activate
|
||||
# Linux/Mac:
|
||||
source .venv/bin/activate
|
||||
|
||||
# 安装依赖
|
||||
pip install -r requirements.txt
|
||||
|
||||
# 配置环境变量
|
||||
cp .env.example .env
|
||||
# 编辑 .env 文件,填入你的配置
|
||||
|
||||
# 启动后端服务
|
||||
python -m uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
|
||||
```
|
||||
|
||||
#### 前端设置
|
||||
|
||||
```bash
|
||||
# 进入前端目录
|
||||
cd frontend
|
||||
|
||||
# 安装依赖
|
||||
npm install
|
||||
|
||||
# 开发模式(需要后端已启动)
|
||||
npm run dev
|
||||
|
||||
# 或构建生产版本
|
||||
npm run build
|
||||
```
|
||||
|
||||
## 🐳 部署方式
|
||||
|
||||
### Docker Compose 部署
|
||||
|
||||
#### 使用 Docker Hub 镜像(推荐)
|
||||
|
||||
项目已发布到 Docker Hub,可直接拉取使用:
|
||||
|
||||
```bash
|
||||
# 查看可用版本
|
||||
# 1. 拉取最新镜像(已包含模型文件)
|
||||
docker pull mumujie/mumuainovel:latest
|
||||
|
||||
# 2. 创建 docker-compose.yml(点击下方展开查看完整配置)
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>📄 点击展开 docker-compose.yml 完整配置</summary>
|
||||
|
||||
```yaml
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:18-alpine
|
||||
container_name: mumuainovel-postgres
|
||||
environment:
|
||||
POSTGRES_DB: ${POSTGRES_DB:-mumuai_novel}
|
||||
POSTGRES_USER: ${POSTGRES_USER:-mumuai}
|
||||
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-123456}
|
||||
POSTGRES_INITDB_ARGS: "--encoding=UTF8 --locale=C"
|
||||
TZ: ${TZ:-Asia/Shanghai}
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
- ./backend/scripts/init_postgres.sql:/docker-entrypoint-initdb.d/init.sql:ro
|
||||
ports:
|
||||
- "${POSTGRES_PORT:-5432}:5432"
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-mumuai} -d ${POSTGRES_DB:-mumuai_novel}"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
start_period: 10s
|
||||
networks:
|
||||
- ai-story-network
|
||||
command:
|
||||
- postgres
|
||||
- -c
|
||||
- max_connections=${POSTGRES_MAX_CONNECTIONS:-200}
|
||||
- -c
|
||||
- shared_buffers=${POSTGRES_SHARED_BUFFERS:-256MB}
|
||||
- -c
|
||||
- effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-1GB}
|
||||
- -c
|
||||
- maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}
|
||||
- -c
|
||||
- checkpoint_completion_target=${POSTGRES_CHECKPOINT_COMPLETION_TARGET:-0.9}
|
||||
- -c
|
||||
- wal_buffers=${POSTGRES_WAL_BUFFERS:-16MB}
|
||||
- -c
|
||||
- default_statistics_target=${POSTGRES_DEFAULT_STATISTICS_TARGET:-100}
|
||||
- -c
|
||||
- random_page_cost=${POSTGRES_RANDOM_PAGE_COST:-1.1}
|
||||
- -c
|
||||
- effective_io_concurrency=${POSTGRES_EFFECTIVE_IO_CONCURRENCY:-200}
|
||||
- -c
|
||||
- work_mem=${POSTGRES_WORK_MEM:-4MB}
|
||||
- -c
|
||||
- min_wal_size=${POSTGRES_MIN_WAL_SIZE:-1GB}
|
||||
- -c
|
||||
- max_wal_size=${POSTGRES_MAX_WAL_SIZE:-4GB}
|
||||
|
||||
mumuainovel:
|
||||
image: mumujie/mumuainovel:latest
|
||||
container_name: mumuainovel
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
ports:
|
||||
- "${APP_PORT:-8000}:8000"
|
||||
volumes:
|
||||
- ./logs:/app/logs
|
||||
- ./.env:/app/.env:ro
|
||||
environment:
|
||||
# 应用配置
|
||||
- APP_NAME=${APP_NAME:-MuMuAINovel}
|
||||
- APP_VERSION=${APP_VERSION:-1.0.0}
|
||||
- APP_HOST=${APP_HOST:-0.0.0.0}
|
||||
- APP_PORT=8000
|
||||
- DEBUG=${DEBUG:-false}
|
||||
# 数据库配置
|
||||
- DATABASE_URL=postgresql+asyncpg://${POSTGRES_USER:-mumuai}:${POSTGRES_PASSWORD:-123456}@postgres:5432/${POSTGRES_DB:-mumuai_novel}
|
||||
- DB_HOST=postgres
|
||||
- DB_PORT=5432
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-123456}
|
||||
# PostgreSQL 连接池配置
|
||||
- DATABASE_POOL_SIZE=${DATABASE_POOL_SIZE:-30}
|
||||
- DATABASE_MAX_OVERFLOW=${DATABASE_MAX_OVERFLOW:-20}
|
||||
- DATABASE_POOL_TIMEOUT=${DATABASE_POOL_TIMEOUT:-60}
|
||||
- DATABASE_POOL_RECYCLE=${DATABASE_POOL_RECYCLE:-1800}
|
||||
- DATABASE_POOL_PRE_PING=${DATABASE_POOL_PRE_PING:-True}
|
||||
- DATABASE_POOL_USE_LIFO=${DATABASE_POOL_USE_LIFO:-True}
|
||||
# 代理配置(可选)
|
||||
- HTTP_PROXY=${HTTP_PROXY:-}
|
||||
- HTTPS_PROXY=${HTTPS_PROXY:-}
|
||||
- NO_PROXY=${NO_PROXY:-localhost,127.0.0.1}
|
||||
# AI 服务配置
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY:-}
|
||||
- OPENAI_BASE_URL=${OPENAI_BASE_URL:-https://api.openai.com/v1}
|
||||
- GEMINI_API_KEY=${GEMINI_API_KEY:-}
|
||||
- GEMINI_BASE_URL=${GEMINI_BASE_URL:-}
|
||||
- ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY:-}
|
||||
- ANTHROPIC_BASE_URL=${ANTHROPIC_BASE_URL:-}
|
||||
- DEFAULT_AI_PROVIDER=${DEFAULT_AI_PROVIDER:-openai}
|
||||
- DEFAULT_MODEL=${DEFAULT_MODEL:-gpt-4o-mini}
|
||||
- DEFAULT_TEMPERATURE=${DEFAULT_TEMPERATURE:-0.7}
|
||||
- DEFAULT_MAX_TOKENS=${DEFAULT_MAX_TOKENS:-32000}
|
||||
# LinuxDO OAuth 配置
|
||||
- LINUXDO_CLIENT_ID=${LINUXDO_CLIENT_ID:-11111}
|
||||
- LINUXDO_CLIENT_SECRET=${LINUXDO_CLIENT_SECRET:-11111}
|
||||
- LINUXDO_REDIRECT_URI=${LINUXDO_REDIRECT_URI:-http://localhost:8000/api/auth/linuxdo/callback}
|
||||
- FRONTEND_URL=${FRONTEND_URL:-http://localhost:8000}
|
||||
# 本地账户登录配置
|
||||
- LOCAL_AUTH_ENABLED=${LOCAL_AUTH_ENABLED:-true}
|
||||
- LOCAL_AUTH_USERNAME=${LOCAL_AUTH_USERNAME:-admin}
|
||||
- LOCAL_AUTH_PASSWORD=${LOCAL_AUTH_PASSWORD:-admin123}
|
||||
- LOCAL_AUTH_DISPLAY_NAME=${LOCAL_AUTH_DISPLAY_NAME:-本地管理员}
|
||||
# 会话配置
|
||||
- SESSION_EXPIRE_MINUTES=${SESSION_EXPIRE_MINUTES:-120}
|
||||
- SESSION_REFRESH_THRESHOLD_MINUTES=${SESSION_REFRESH_THRESHOLD_MINUTES:-30}
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 30s
|
||||
networks:
|
||||
- ai-story-network
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
driver: local
|
||||
|
||||
networks:
|
||||
ai-story-network:
|
||||
driver: bridge
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
```bash
|
||||
# 3. 启动服务
|
||||
docker-compose up -d
|
||||
|
||||
# 4. 查看日志
|
||||
docker-compose logs -f
|
||||
|
||||
# 5. 更新到最新版本
|
||||
docker-compose pull
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
> **💡 提示**: Docker Hub 镜像已包含所有依赖和模型文件,无需额外下载
|
||||
|
||||
### 本地开发 / 从源码构建
|
||||
|
||||
#### 前置准备
|
||||
|
||||
```bash
|
||||
# ⚠️ 重要:如果从源码构建,需要先下载 embedding 模型文件
|
||||
# 模型文件较大(约 400MB),需放置到以下目录:
|
||||
# backend/embedding/models--sentence-transformers--paraphrase-multilingual-MiniLM-L12-v2/
|
||||
#
|
||||
# 📥 获取方式:
|
||||
# - 加入项目 QQ 群或 Linux DO 讨论区获取下载链接
|
||||
# - 群号:见项目主页
|
||||
# - Linux DO:https://linux.do/t/topic/1100112
|
||||
```
|
||||
|
||||
#### 后端
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate # Windows: .venv\Scripts\activate
|
||||
pip install -r requirements.txt
|
||||
|
||||
# 配置 .env 文件
|
||||
cp .env.example .env
|
||||
# 编辑 .env 填入必要配置
|
||||
|
||||
# 启动 PostgreSQL(可使用 Docker)
|
||||
docker run -d --name postgres \
|
||||
-e POSTGRES_PASSWORD=your_password \
|
||||
-e POSTGRES_DB=mumuai_novel \
|
||||
-p 5432:5432 \
|
||||
postgres:18-alpine
|
||||
|
||||
# 启动后端
|
||||
python -m uvicorn app.main:app --host localhost --port 8000 --reload
|
||||
```
|
||||
|
||||
#### 前端
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
npm install
|
||||
npm run dev # 开发模式
|
||||
npm run build # 生产构建
|
||||
```
|
||||
|
||||
## ⚙️ 配置说明
|
||||
|
||||
### 必需配置
|
||||
|
||||
创建 `.env` 文件:
|
||||
|
||||
```bash
|
||||
# PostgreSQL 数据库(必需)
|
||||
DATABASE_URL=postgresql+asyncpg://mumuai:your_password@postgres:5432/mumuai_novel
|
||||
POSTGRES_PASSWORD=your_secure_password
|
||||
|
||||
# AI 服务
|
||||
OPENAI_API_KEY=your_openai_key
|
||||
OPENAI_BASE_URL=https://api.openai.com/v1
|
||||
DEFAULT_AI_PROVIDER=openai
|
||||
DEFAULT_MODEL=gpt-4o-mini
|
||||
|
||||
# 本地账户登录
|
||||
LOCAL_AUTH_ENABLED=true
|
||||
LOCAL_AUTH_USERNAME=admin
|
||||
LOCAL_AUTH_PASSWORD=your_password
|
||||
```
|
||||
|
||||
### 可选配置
|
||||
|
||||
```bash
|
||||
# LinuxDO OAuth
|
||||
LINUXDO_CLIENT_ID=your_client_id
|
||||
LINUXDO_CLIENT_SECRET=your_client_secret
|
||||
LINUXDO_REDIRECT_URI=http://localhost:8000/api/auth/callback
|
||||
|
||||
# PostgreSQL 连接池(高并发优化)
|
||||
DATABASE_POOL_SIZE=30
|
||||
DATABASE_MAX_OVERFLOW=20
|
||||
```
|
||||
|
||||
### 中转 API 配置
|
||||
|
||||
支持所有 OpenAI 兼容格式的中转服务:
|
||||
|
||||
```bash
|
||||
# New API 示例
|
||||
OPENAI_API_KEY=sk-xxxxxxxx
|
||||
OPENAI_BASE_URL=https://api.new-api.com/v1
|
||||
|
||||
# 其他中转服务
|
||||
OPENAI_BASE_URL=https://your-proxy-service.com/v1
|
||||
```
|
||||
|
||||
## 🐳 Docker 部署详情
|
||||
|
||||
### 服务架构
|
||||
|
||||
- **postgres**: PostgreSQL 18 数据库
|
||||
- 端口: 5432
|
||||
- 数据持久化: `postgres_data` volume
|
||||
- 初始化脚本: `backend/scripts/init_postgres.sql`(自动挂载)
|
||||
- 优化配置: 支持 80-150 并发用户
|
||||
|
||||
- **mumuainovel**: 主应用服务
|
||||
- 端口: 8000
|
||||
- 日志目录: `./logs`
|
||||
- 配置挂载: `.env` 文件
|
||||
- 自动等待数据库就绪
|
||||
- 健康检查: 每 30 秒检测一次
|
||||
|
||||
### 重要文件说明
|
||||
|
||||
| 文件 | 说明 | 是否必需 |
|
||||
|------|------|---------|
|
||||
| `.env` | 环境配置(API Key、数据库密码等) | ✅ 必需 |
|
||||
| `docker-compose.yml` | 服务编排配置 | ✅ 必需 |
|
||||
| `backend/scripts/init_postgres.sql` | PostgreSQL 扩展安装脚本 | ✅ 自动挂载 |
|
||||
| `backend/embedding/models--*/` | Embedding 模型文件 | ⚠️ 自建需要 |
|
||||
|
||||
> **注意**: 使用 Docker Hub 镜像时,模型文件已包含在镜像中,无需额外下载
|
||||
|
||||
### 常用命令
|
||||
|
||||
```bash
|
||||
# 启动服务
|
||||
docker-compose up -d
|
||||
|
||||
# 查看状态
|
||||
docker-compose ps
|
||||
|
||||
# 查看日志
|
||||
docker-compose logs -f
|
||||
|
||||
@@ -138,346 +440,67 @@ docker-compose down
|
||||
# 重启服务
|
||||
docker-compose restart
|
||||
|
||||
# 更新到最新版本
|
||||
docker-compose pull
|
||||
docker-compose up -d
|
||||
# 查看资源使用
|
||||
docker stats
|
||||
```
|
||||
|
||||
#### Docker Compose 配置文件示例
|
||||
### 数据持久化
|
||||
|
||||
使用 Docker Hub 镜像的完整配置:
|
||||
- `./postgres_data` - PostgreSQL 数据库文件
|
||||
- `./logs` - 应用日志文件
|
||||
|
||||
### 端口配置
|
||||
|
||||
修改 `docker-compose.yml` 中的端口映射:
|
||||
|
||||
```yaml
|
||||
services:
|
||||
ai-story:
|
||||
image: mumujie/mumuainovel:latest
|
||||
container_name: mumuainovel
|
||||
ports:
|
||||
- "8800:8000" # 宿主机端口:容器端口
|
||||
volumes:
|
||||
# 持久化数据库和日志
|
||||
- ./data:/app/data
|
||||
- ./logs:/app/logs
|
||||
# 挂载环境变量文件
|
||||
- ./.env:/app/.env:ro
|
||||
environment:
|
||||
- APP_NAME=mumuainovel
|
||||
- APP_VERSION=1.0.0
|
||||
- APP_HOST=0.0.0.0
|
||||
- APP_PORT=8000
|
||||
- DEBUG=false
|
||||
# 其他环境变量会从 .env 文件自动加载
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 10s
|
||||
networks:
|
||||
- ai-story-network
|
||||
|
||||
networks:
|
||||
ai-story-network:
|
||||
driver: bridge
|
||||
```
|
||||
|
||||
### 生产环境部署建议
|
||||
|
||||
#### 1. 环境变量配置
|
||||
|
||||
**必需配置**:
|
||||
- `OPENAI_API_KEY` 或 `GEMINI_API_KEY`:至少配置一个 AI 服务
|
||||
- `LOCAL_AUTH_PASSWORD`:修改为强密码
|
||||
|
||||
**推荐配置**:
|
||||
- `OPENAI_BASE_URL`:如果使用中转 API,修改为中转服务地址
|
||||
- `DEFAULT_AI_PROVIDER`:根据你的 API Key 选择 `openai`、`gemini` 或 `anthropic`
|
||||
- `DEFAULT_MODEL`:选择合适的模型(如 `gpt-4o-mini`、`gemini-2.0-flash-exp`)
|
||||
|
||||
#### 2. 数据持久化
|
||||
|
||||
数据目录已通过 volume 挂载,数据不会丢失:
|
||||
- `./data`:SQLite 数据库文件
|
||||
- `./logs`:应用日志文件
|
||||
|
||||
#### 3. 端口配置
|
||||
|
||||
默认端口映射:`8800:8000`
|
||||
- 宿主机端口:`8800`(可自定义修改)
|
||||
- 容器内端口:`8000`(固定,不要修改)
|
||||
|
||||
访问地址:`http://your-server-ip:8800`
|
||||
|
||||
|
||||
配置后记得更新 `.env` 中的 `LINUXDO_REDIRECT_URI` 和 `FRONTEND_URL`。
|
||||
|
||||
#### 5. 资源限制(可选)
|
||||
|
||||
在 `docker-compose.yml` 中添加资源限制:
|
||||
|
||||
```yaml
|
||||
services:
|
||||
ai-story:
|
||||
# ... 其他配置
|
||||
deploy:
|
||||
resources:
|
||||
limits:
|
||||
cpus: '2.0'
|
||||
memory: 2G
|
||||
reservations:
|
||||
cpus: '0.5'
|
||||
memory: 512M
|
||||
```
|
||||
|
||||
### 端口说明
|
||||
|
||||
- **默认端口**:`8800`(宿主机)→ `8000`(容器)
|
||||
- **可自定义**:修改 docker-compose.yml 中的 `ports` 配置
|
||||
- **健康检查**:容器内部使用 `8000` 端口进行健康检查
|
||||
|
||||
## ⚙️ 配置说明
|
||||
|
||||
### 环境变量
|
||||
|
||||
创建 `.env` 文件并配置以下变量:
|
||||
|
||||
```bash
|
||||
# ===== AI 服务配置(必填)=====
|
||||
# OpenAI 配置(支持官方API和中转API)
|
||||
OPENAI_API_KEY=your_openai_key_here
|
||||
OPENAI_BASE_URL=https://api.openai.com/v1
|
||||
|
||||
# Anthropic 配置
|
||||
# ANTHROPIC_API_KEY=your_anthropic_key_here
|
||||
# ANTHROPIC_BASE_URL=https://api.anthropic.com
|
||||
|
||||
# 中转API配置示例(使用OpenAI格式)
|
||||
# New API 中转服务
|
||||
# OPENAI_API_KEY=your_newapi_key_here
|
||||
# OPENAI_BASE_URL=https://api.new-api.com/v1
|
||||
|
||||
# 默认 AI 提供商和模型
|
||||
DEFAULT_AI_PROVIDER=openai
|
||||
DEFAULT_MODEL=gpt-4o-mini
|
||||
DEFAULT_TEMPERATURE=0.8
|
||||
DEFAULT_MAX_TOKENS=32000
|
||||
|
||||
# ===== 应用配置 =====
|
||||
APP_NAME=MuMuAINovel
|
||||
APP_VERSION=1.0.0
|
||||
APP_HOST=0.0.0.0
|
||||
APP_PORT=8000
|
||||
DEBUG=false
|
||||
|
||||
# ===== LinuxDO OAuth 配置(可选)=====
|
||||
LINUXDO_CLIENT_ID=your_client_id_here
|
||||
LINUXDO_CLIENT_SECRET=your_client_secret_here
|
||||
LINUXDO_REDIRECT_URI=http://localhost:8000/api/auth/callback
|
||||
FRONTEND_URL=http://localhost:8000
|
||||
|
||||
# ===== 本地账户登录配置 =====
|
||||
LOCAL_AUTH_ENABLED=true
|
||||
LOCAL_AUTH_USERNAME=admin
|
||||
LOCAL_AUTH_PASSWORD=your_secure_password_here
|
||||
LOCAL_AUTH_DISPLAY_NAME=管理员
|
||||
|
||||
# 会话配置
|
||||
# 会话过期时间(分钟),默认120分钟(2小时)
|
||||
SESSION_EXPIRE_MINUTES=120
|
||||
# 会话刷新阈值(分钟),剩余时间少于此值时可刷新,默认30分钟
|
||||
SESSION_REFRESH_THRESHOLD_MINUTES=30
|
||||
|
||||
# ===== CORS 配置(生产环境)=====
|
||||
# CORS_ORIGINS=https://your-domain.com,https://www.your-domain.com
|
||||
```
|
||||
|
||||
### AI 模型配置
|
||||
|
||||
项目支持多个 AI 提供商,你可以根据需要配置:
|
||||
|
||||
| 提供商 | 推荐模型 | 用途 |
|
||||
|--------|---------|------|
|
||||
| OpenAI | gpt-4, gpt-3.5-turbo | 高质量文本生成 |
|
||||
| Anthropic | claude-3-opus, claude-3-sonnet | 长文本创作 |
|
||||
|
||||
#### 使用中转API服务
|
||||
|
||||
如果你无法直接访问 OpenAI 官方 API,或者想使用更经济实惠的中转服务,本项目完全支持各种 OpenAI 兼容格式的中转 API:
|
||||
|
||||
##### 配置方法
|
||||
|
||||
只需修改 `.env` 文件中的两个参数:
|
||||
|
||||
```bash
|
||||
# 1. 填入中转服务提供的 API Key
|
||||
OPENAI_API_KEY=your_api_key_from_proxy_service
|
||||
|
||||
# 2. 修改 Base URL 为中转服务的地址
|
||||
OPENAI_BASE_URL=https://your-proxy-service.com/v1
|
||||
```
|
||||
|
||||
##### 常见中转服务配置示例
|
||||
|
||||
**New API**
|
||||
```bash
|
||||
OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxx
|
||||
OPENAI_BASE_URL=https://api.new-api.com/v1
|
||||
```
|
||||
|
||||
**API2D**
|
||||
```bash
|
||||
OPENAI_API_KEY=fk-xxxxxxxxxxxxxxxx
|
||||
OPENAI_BASE_URL=https://api.api2d.com/v1
|
||||
```
|
||||
|
||||
**OpenAI-SB**
|
||||
```bash
|
||||
OPENAI_API_KEY=sb-xxxxxxxxxxxxxxxx
|
||||
OPENAI_BASE_URL=https://api.openai-sb.com/v1
|
||||
```
|
||||
|
||||
**自建 One API / New API**
|
||||
```bash
|
||||
OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxx
|
||||
OPENAI_BASE_URL=https://your-domain.com/v1
|
||||
```
|
||||
|
||||
##### 注意事项
|
||||
|
||||
- ✅ 所有支持 OpenAI 接口格式的服务都可以使用
|
||||
- ✅ 确保中转服务的 Base URL 以 `/v1` 结尾
|
||||
- ✅ 根据中转服务支持的模型,修改 `DEFAULT_MODEL` 参数
|
||||
- ⚠️ 不同中转服务的模型名称可能不同,请参考服务商文档
|
||||
- ⚠️ 部分中转服务可能对请求频率或并发有限制
|
||||
|
||||
##### 推荐的中转服务
|
||||
|
||||
如果你需要中转服务,以下是一些常见选择:
|
||||
|
||||
1. **New API** - 开源的 API 分发系统,支持多种模型
|
||||
2. **API2D** - 国内稳定的 API 中转服务
|
||||
3. **OpenAI-SB** - 提供多种 AI 模型的中转
|
||||
4. **自建服务** - 使用 One API 或 New API 自行搭建
|
||||
|
||||
> 💡 提示:使用中转服务时,请确保服务提供商的可靠性和数据安全性
|
||||
|
||||
### 登录方式配置
|
||||
|
||||
#### 本地账户登录(默认启用)
|
||||
|
||||
适合个人使用或小型团队:
|
||||
|
||||
```bash
|
||||
LOCAL_AUTH_ENABLED=true
|
||||
LOCAL_AUTH_USERNAME=admin
|
||||
LOCAL_AUTH_PASSWORD=your_password
|
||||
```
|
||||
|
||||
#### LinuxDO OAuth 登录
|
||||
|
||||
适合需要社区集成的场景,需要在 [LinuxDO](https://linux.do) 注册 OAuth 应用:
|
||||
|
||||
```bash
|
||||
LINUXDO_CLIENT_ID=your_client_id
|
||||
LINUXDO_CLIENT_SECRET=your_client_secret
|
||||
LINUXDO_REDIRECT_URI=http://your-domain:8000/api/auth/callback
|
||||
ports:
|
||||
- "8800:8000" # 宿主机:容器
|
||||
```
|
||||
|
||||
## 📁 项目结构
|
||||
|
||||
```
|
||||
MuMuAINovel/
|
||||
├── backend/ # 后端服务
|
||||
├── backend/ # 后端服务
|
||||
│ ├── app/
|
||||
│ │ ├── api/ # API 路由
|
||||
│ │ │ ├── auth.py # 认证接口
|
||||
│ │ │ ├── projects.py # 项目管理
|
||||
│ │ │ ├── chapters.py # 章节管理
|
||||
│ │ │ ├── characters.py # 角色管理
|
||||
│ │ │ ├── wizard_stream.py # 向导流式生成
|
||||
│ │ │ └── ...
|
||||
│ │ ├── models/ # 数据模型
|
||||
│ │ ├── schemas/ # Pydantic 模型
|
||||
│ │ ├── services/ # 业务逻辑
|
||||
│ │ │ ├── ai_service.py # AI 服务封装
|
||||
│ │ │ └── oauth_service.py # OAuth 服务
|
||||
│ │ ├── middleware/ # 中间件
|
||||
│ │ ├── utils/ # 工具函数
|
||||
│ │ ├── config.py # 配置管理
|
||||
│ │ ├── database.py # 数据库连接
|
||||
│ │ └── main.py # 应用入口
|
||||
│ ├── data/ # 数据存储目录
|
||||
│ ├── static/ # 前端静态文件(构建后)
|
||||
│ ├── requirements.txt # Python 依赖
|
||||
│ └── .env.example # 环境变量示例
|
||||
├── frontend/ # 前端应用
|
||||
│ │ ├── api/ # API 路由
|
||||
│ │ ├── models/ # 数据模型
|
||||
│ │ ├── services/ # 业务逻辑
|
||||
│ │ ├── middleware/ # 中间件
|
||||
│ │ ├── database.py # 数据库连接
|
||||
│ │ └── main.py # 应用入口
|
||||
│ ├── scripts/ # 工具脚本
|
||||
│ └── requirements.txt # Python 依赖
|
||||
├── frontend/ # 前端应用
|
||||
│ ├── src/
|
||||
│ │ ├── pages/ # 页面组件
|
||||
│ │ │ ├── ProjectList.tsx # 项目列表
|
||||
│ │ │ ├── ProjectWizardNew.tsx # 创建向导
|
||||
│ │ │ ├── Chapters.tsx # 章节管理
|
||||
│ │ │ ├── Characters.tsx # 角色管理
|
||||
│ │ │ └── ...
|
||||
│ │ ├── components/ # 通用组件
|
||||
│ │ ├── services/ # API 服务
|
||||
│ │ ├── store/ # 状态管理(Zustand)
|
||||
│ │ ├── types/ # TypeScript 类型
|
||||
│ │ └── utils/ # 工具函数
|
||||
│ ├── package.json
|
||||
│ └── vite.config.ts
|
||||
├── docker-compose.yml # Docker Compose 配置
|
||||
├── Dockerfile # Docker 镜像构建
|
||||
└── README.md # 项目说明文档
|
||||
│ │ ├── pages/ # 页面组件
|
||||
│ │ ├── components/ # 通用组件
|
||||
│ │ ├── services/ # API 服务
|
||||
│ │ └── store/ # 状态管理
|
||||
│ └── package.json
|
||||
├── docker-compose.yml # Docker Compose 配置
|
||||
├── Dockerfile # Docker 镜像构建
|
||||
└── README.md
|
||||
```
|
||||
|
||||
## 🛠️ 技术栈
|
||||
|
||||
### 后端
|
||||
**后端**: FastAPI • PostgreSQL • SQLAlchemy • OpenAI/Claude/Gemini SDK
|
||||
|
||||
- **框架**:FastAPI 0.109.0
|
||||
- **数据库**:SQLite + SQLAlchemy(异步)
|
||||
- **AI 集成**:OpenAI、Anthropic、Google Gemini SDK
|
||||
- **认证**:LinuxDO OAuth2、本地账户
|
||||
- **日志**:Python logging + 文件轮转
|
||||
|
||||
### 前端
|
||||
|
||||
- **框架**:React 18.3 + TypeScript
|
||||
- **UI 库**:Ant Design 5.27
|
||||
- **路由**:React Router 6.28
|
||||
- **状态管理**:Zustand 5.0
|
||||
- **HTTP 客户端**:Axios
|
||||
- **构建工具**:Vite 7.1
|
||||
**前端**: React 18 • TypeScript • Ant Design • Zustand • Vite
|
||||
|
||||
## 📖 使用指南
|
||||
|
||||
### 创建第一个小说项目
|
||||
|
||||
1. **登录系统**
|
||||
- 使用本地账户或 LinuxDO 账户登录
|
||||
|
||||
2. **创建项目**
|
||||
- 点击"创建项目"按钮
|
||||
- 选择"使用向导创建"或"手动创建"
|
||||
|
||||
3. **使用向导(推荐)**
|
||||
- 输入小说基本信息(标题、类型、背景等)
|
||||
- AI 自动生成大纲、角色和世界观
|
||||
- 实时查看生成进度
|
||||
|
||||
4. **编辑和完善**
|
||||
- 在项目详情页查看和编辑大纲
|
||||
- 管理角色和人物关系
|
||||
- 生成和编辑章节内容
|
||||
|
||||
1. **登录系统** - 使用本地账户或 LinuxDO 账户
|
||||
2. **创建项目** - 选择"使用向导创建"
|
||||
3. **AI 生成** - 输入基本信息,AI 自动生成大纲和角色
|
||||
4. **编辑完善** - 管理角色关系,生成和编辑章节
|
||||
|
||||
### API 文档
|
||||
|
||||
应用启动后,可访问自动生成的 API 文档:
|
||||
|
||||
- Swagger UI:`http://localhost:8000/docs`
|
||||
- ReDoc:`http://localhost:8000/redoc`
|
||||
- Swagger UI: `http://localhost:8000/docs`
|
||||
- ReDoc: `http://localhost:8000/redoc`
|
||||
|
||||
## 🤝 贡献
|
||||
|
||||
@@ -489,40 +512,44 @@ MuMuAINovel/
|
||||
4. 推送到分支 (`git push origin feature/AmazingFeature`)
|
||||
5. 提交 Pull Request
|
||||
|
||||
### 贡献者
|
||||
|
||||
感谢所有为本项目做出贡献的开发者!
|
||||
|
||||
<a href="https://github.com/xiamuceer-j/MuMuAINovel/graphs/contributors">
|
||||
<img src="https://contrib.rocks/image?repo=xiamuceer-j/MuMuAINovel" />
|
||||
</a>
|
||||
|
||||
## 📝 许可证
|
||||
|
||||
本项目采用 [GNU General Public License v3.0](https://www.gnu.org/licenses/gpl-3.0.html) 开源协议
|
||||
本项目采用 [GNU General Public License v3.0](LICENSE)
|
||||
|
||||
**这意味着:**
|
||||
|
||||
- ✅ **可以** - 自由使用、复制、修改和分发本项目
|
||||
- ✅ **可以** - 用于商业目的
|
||||
- ✅ **可以** - 用于个人学习和研究
|
||||
- 📝 **必须** - 开源你的修改版本
|
||||
- 📝 **必须** - 保留原作者版权声明
|
||||
- 📝 **必须** - 以相同的 GPL v3 协议发布衍生作品
|
||||
|
||||
详见 [LICENSE](LICENSE) 文件
|
||||
**GPL v3 意味着:**
|
||||
- ✅ 可自由使用、修改和分发
|
||||
- ✅ 可用于商业目的
|
||||
- 📝 必须开源修改版本
|
||||
- 📝 必须保留原作者版权
|
||||
- 📝 衍生作品必须使用 GPL v3 协议
|
||||
|
||||
## 🙏 致谢
|
||||
|
||||
- [FastAPI](https://fastapi.tiangolo.com/) - 现代化的 Python Web 框架
|
||||
- [React](https://react.dev/) - 用户界面构建库
|
||||
- [Ant Design](https://ant.design/) - 企业级 UI 设计语言
|
||||
- [OpenAI](https://openai.com/) / [Anthropic](https://www.anthropic.com/) - AI 模型提供商
|
||||
- [FastAPI](https://fastapi.tiangolo.com/) - Python Web 框架
|
||||
- [React](https://react.dev/) - 前端框架
|
||||
- [Ant Design](https://ant.design/) - UI 组件库
|
||||
- [PostgreSQL](https://www.postgresql.org/) - 数据库
|
||||
|
||||
## 📧 联系方式
|
||||
|
||||
如有问题或建议,欢迎通过以下方式联系:
|
||||
|
||||
- 提交 [Issue](https://github.com/yourusername/MuMuAINovel/issues)
|
||||
- Linux DO [LD](https://linux.do/t/topic/1100112)
|
||||
- 提交 [Issue](https://github.com/xiamuceer-j/MuMuAINovel/issues)
|
||||
- Linux DO [讨论](https://linux.do/t/topic/1106333)
|
||||
- 加入QQ群 [QQ群](frontend/public/qq.jpg)
|
||||
- 加入WX群 [WX群](frontend/public/WX.png)
|
||||
|
||||
---
|
||||
|
||||
<div align="center">
|
||||
|
||||
**如果这个项目对你有帮助,请给个 ⭐️ Star 支持一下!**
|
||||
**如果这个项目对你有帮助,请给个 ⭐️ Star!**
|
||||
|
||||
Made with ❤️
|
||||
|
||||
|
||||
+78
-38
@@ -1,54 +1,94 @@
|
||||
# AI服务配置
|
||||
# OpenAI配置
|
||||
OPENAI_API_KEY=your_openai_key_here
|
||||
OPENAI_BASE_URL=https://api.openai.com/v1
|
||||
|
||||
# Anthropic配置
|
||||
ANTHROPIC_API_KEY=your_anthropic_key_here
|
||||
ANTHROPIC_BASE_URL=https://api.anthropic.com
|
||||
|
||||
# 默认AI提供商:openai, gemini, anthropic
|
||||
DEFAULT_AI_PROVIDER=openai
|
||||
DEFAULT_MODEL=gpt-4.1
|
||||
DEFAULT_TEMPERATURE=0.8
|
||||
DEFAULT_MAX_TOKENS=32000
|
||||
# ==========================================
|
||||
# MuMuAINovel 配置文件示例
|
||||
# ==========================================
|
||||
# 复制此文件为 .env 并修改配置值
|
||||
# cp .env.example .env
|
||||
|
||||
# ==========================================
|
||||
# 应用配置
|
||||
# ==========================================
|
||||
APP_NAME=MuMuAINovel
|
||||
APP_VERSION=1.0.0
|
||||
APP_VERSION=1.2.6
|
||||
APP_HOST=0.0.0.0
|
||||
APP_PORT=8000
|
||||
DEBUG=true
|
||||
DEBUG=false
|
||||
TZ=Asia/Shanghai
|
||||
|
||||
# LinuxDO OAuth2 配置(可选)
|
||||
# 注意:Docker部署时,LINUXDO_REDIRECT_URI 应该使用实际的域名或服务器IP
|
||||
# 本地开发: http://localhost:8000/api/auth/callback
|
||||
# 生产环境: https://your-domain.com/api/auth/callback 或 http://your-server-ip:8000/api/auth/callback
|
||||
LINUXDO_CLIENT_ID=your_client_id_here
|
||||
LINUXDO_CLIENT_SECRET=your_client_secret_here
|
||||
# ==========================================
|
||||
# PostgreSQL 数据库配置
|
||||
# ==========================================
|
||||
|
||||
# PostgreSQL 连接信息
|
||||
POSTGRES_DB=mumuai_novel
|
||||
POSTGRES_USER=mumuai
|
||||
POSTGRES_PASSWORD=123456
|
||||
POSTGRES_PORT=5432
|
||||
|
||||
# 数据库连接 URL(Docker 部署时自动生成)
|
||||
# DATABASE_URL=postgresql+asyncpg://mumuai:123456@localhost:5432/mumuai_novel
|
||||
|
||||
# ==========================================
|
||||
# SQLite 数据库配置
|
||||
# ==========================================
|
||||
|
||||
# DATABASE_URL=sqlite+aiosqlite:///data/ai_story.db
|
||||
|
||||
# ==========================================
|
||||
# 日志配置
|
||||
# ==========================================
|
||||
LOG_LEVEL=INFO
|
||||
LOG_TO_FILE=true
|
||||
LOG_FILE_PATH=logs/app.log
|
||||
LOG_MAX_BYTES=10485760
|
||||
LOG_BACKUP_COUNT=30
|
||||
|
||||
# ==========================================
|
||||
# CORS 配置
|
||||
# ==========================================
|
||||
CORS_ORIGINS=["http://localhost:8000","http://127.0.0.1:8000"]
|
||||
|
||||
# ==========================================
|
||||
# 代理配置(可选)
|
||||
# ==========================================
|
||||
# HTTP_PROXY=http://your-proxy:port
|
||||
# HTTPS_PROXY=http://your-proxy:port
|
||||
# NO_PROXY=localhost,127.0.0.1
|
||||
|
||||
# ==========================================
|
||||
# AI 服务配置(至少配置一个)
|
||||
# ==========================================
|
||||
|
||||
# OpenAI 配置
|
||||
OPENAI_API_KEY=your_openai_api_key_here
|
||||
OPENAI_BASE_URL=https://api.openai.com/v1
|
||||
|
||||
# 默认 AI 配置
|
||||
DEFAULT_AI_PROVIDER=openai
|
||||
DEFAULT_MODEL=gpt-4o-mini
|
||||
DEFAULT_TEMPERATURE=0.7
|
||||
DEFAULT_MAX_TOKENS=32000
|
||||
|
||||
# ==========================================
|
||||
# LinuxDO OAuth 配置(可选)
|
||||
# ==========================================
|
||||
LINUXDO_CLIENT_ID=11111
|
||||
LINUXDO_CLIENT_SECRET=11111
|
||||
LINUXDO_REDIRECT_URI=http://localhost:8000/api/auth/callback
|
||||
|
||||
# 前端URL配置(用于OAuth回调后重定向到前端)
|
||||
# 本地开发: http://localhost:8000
|
||||
# 生产环境: https://your-domain.com 或 http://your-server-ip:8000
|
||||
FRONTEND_URL=http://localhost:8000
|
||||
|
||||
# 初始管理员(LinuxDO user_id)
|
||||
# INITIAL_ADMIN_LINUXDO_ID=your_linuxdo_user_id
|
||||
|
||||
# ==========================================
|
||||
# 本地账户登录配置
|
||||
# 启用本地账户登录(true/false)
|
||||
# ==========================================
|
||||
LOCAL_AUTH_ENABLED=true
|
||||
# 本地登录用户名
|
||||
LOCAL_AUTH_USERNAME=admin
|
||||
# 本地登录密码
|
||||
LOCAL_AUTH_PASSWORD=your_secure_password_here
|
||||
# 本地用户显示名称
|
||||
LOCAL_AUTH_DISPLAY_NAME=管理员
|
||||
LOCAL_AUTH_PASSWORD=admin123
|
||||
LOCAL_AUTH_DISPLAY_NAME=本地管理员
|
||||
|
||||
# ==========================================
|
||||
# 会话配置
|
||||
# 会话过期时间(分钟),默认120分钟(2小时)
|
||||
# ==========================================
|
||||
SESSION_EXPIRE_MINUTES=120
|
||||
# 会话刷新阈值(分钟),剩余时间少于此值时可刷新,默认30分钟
|
||||
SESSION_REFRESH_THRESHOLD_MINUTES=30
|
||||
|
||||
# CORS配置(生产环境)
|
||||
# 允许的跨域来源,多个用逗号分隔
|
||||
# CORS_ORIGINS=https://your-domain.com,https://www.your-domain.com
|
||||
@@ -0,0 +1,48 @@
|
||||
# Alembic Database Migration Profile - PostgreSQL
|
||||
# Database version management for the MuMuAINovel project
|
||||
|
||||
[alembic]
|
||||
# Migration Script storage directory (PostgreSQL)
|
||||
script_location = alembic/postgres
|
||||
|
||||
# Template File Path (for generating migration scripts)
|
||||
file_template = %%(year)d%%(month).2d%%(day).2d_%%(hour).2d%%(minute).2d_%%(rev)s_%%(slug)s
|
||||
|
||||
# Database connection string
|
||||
# Note: The actual connection string is read from the environment variable in env.py
|
||||
# sqlalchemy.url = postgresql+asyncpg://mumuai:password@localhost:5432/mumuai_novel
|
||||
|
||||
# Log Configuration
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
@@ -0,0 +1,48 @@
|
||||
# Alembic Database Migration Profile - SQLite
|
||||
# Database version management for the MuMuAINovel project
|
||||
|
||||
[alembic]
|
||||
# Migration Script storage directory (SQLite)
|
||||
script_location = alembic/sqlite
|
||||
|
||||
# Template File Path (for generating migration scripts)
|
||||
file_template = %%(year)d%%(month).2d%%(day).2d_%%(hour).2d%%(minute).2d_%%(rev)s_%%(slug)s
|
||||
|
||||
# Database connection string
|
||||
# Note: The actual connection string is read from the environment variable in env.py
|
||||
# sqlalchemy.url = sqlite+aiosqlite:///data/ai_story.db
|
||||
|
||||
# Log Configuration
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
@@ -0,0 +1,145 @@
|
||||
# Alembic 数据库迁移指南
|
||||
|
||||
本项目支持 **PostgreSQL** 和 **SQLite** 两种数据库,使用独立的 Alembic 配置管理迁移。
|
||||
|
||||
## 📁 目录结构
|
||||
|
||||
```
|
||||
backend/
|
||||
├── alembic-postgres.ini # PostgreSQL 配置文件
|
||||
├── alembic-sqlite.ini # SQLite 配置文件
|
||||
├── alembic/
|
||||
│ ├── postgres/ # PostgreSQL 迁移脚本目录
|
||||
│ │ ├── env.py
|
||||
│ │ ├── script.py.mako
|
||||
│ │ └── versions/ # PostgreSQL 迁移版本
|
||||
│ │ ├── 20251226_1008_ee0a189f1532_初始数据库结构.py
|
||||
│ │ └── 20251226_1102_e411428f00c0_初始化预置数据.py
|
||||
│ └── sqlite/ # SQLite 迁移脚本目录
|
||||
│ ├── env.py
|
||||
│ ├── script.py.mako
|
||||
│ └── versions/ # SQLite 迁移版本
|
||||
│ └── 20251226_1322_fbeb1038c728_初始化sqlite数据库.py
|
||||
```
|
||||
|
||||
## 🚀 使用方法
|
||||
|
||||
### 1. PostgreSQL 数据库
|
||||
|
||||
#### 配置环境变量
|
||||
```bash
|
||||
# .env 文件
|
||||
DATABASE_URL=postgresql+asyncpg://username:password@localhost:5432/database_name
|
||||
```
|
||||
|
||||
#### 生成迁移脚本
|
||||
```bash
|
||||
cd backend
|
||||
alembic -c alembic-postgres.ini revision --autogenerate -m "描述信息"
|
||||
```
|
||||
|
||||
#### 应用迁移
|
||||
```bash
|
||||
alembic -c alembic-postgres.ini upgrade head
|
||||
```
|
||||
|
||||
#### 回退迁移
|
||||
```bash
|
||||
alembic -c alembic-postgres.ini downgrade -1
|
||||
```
|
||||
|
||||
#### 查看迁移历史
|
||||
```bash
|
||||
alembic -c alembic-postgres.ini history
|
||||
alembic -c alembic-postgres.ini current
|
||||
```
|
||||
|
||||
### 2. SQLite 数据库
|
||||
|
||||
#### 配置环境变量
|
||||
```bash
|
||||
# .env 文件
|
||||
DATABASE_URL=sqlite+aiosqlite:///./data/mumuai.db
|
||||
```
|
||||
|
||||
#### 生成迁移脚本
|
||||
```bash
|
||||
cd backend
|
||||
alembic -c alembic-sqlite.ini revision --autogenerate -m "描述信息"
|
||||
```
|
||||
|
||||
#### 应用迁移
|
||||
```bash
|
||||
alembic -c alembic-sqlite.ini upgrade head
|
||||
```
|
||||
|
||||
#### 回退迁移
|
||||
```bash
|
||||
alembic -c alembic-sqlite.ini downgrade -1
|
||||
```
|
||||
|
||||
#### 查看迁移历史
|
||||
```bash
|
||||
alembic -c alembic-sqlite.ini history
|
||||
alembic -c alembic-sqlite.ini current
|
||||
```
|
||||
|
||||
## ⚙️ 关键配置差异
|
||||
|
||||
### PostgreSQL (alembic/postgres/env.py)
|
||||
- `render_as_batch=False` - 直接支持 ALTER TABLE
|
||||
- 使用 `server_default=sa.text('now()')`
|
||||
|
||||
### SQLite (alembic/sqlite/env.py)
|
||||
- `render_as_batch=True` - 通过重建表实现 ALTER TABLE
|
||||
- 使用 `server_default=sa.text('(CURRENT_TIMESTAMP)')` - SQLite 格式
|
||||
|
||||
## 📝 注意事项
|
||||
|
||||
### SQLite 限制
|
||||
1. **并发写入**:同时只允许一个写操作
|
||||
2. **ALTER TABLE 限制**:某些操作需要重建表(Alembic 的批处理模式会自动处理)
|
||||
3. **类型映射**:
|
||||
- `JSON` → `TEXT` (SQLAlchemy 自动处理)
|
||||
- `BOOLEAN` → `INTEGER` (0/1)
|
||||
- `DEFAULT now()` → `DEFAULT CURRENT_TIMESTAMP`
|
||||
|
||||
### PostgreSQL 优势
|
||||
1. **高并发支持**:多用户同时读写
|
||||
2. **完整的 ALTER TABLE 支持**
|
||||
3. **高级特性**:全文搜索、JSON 操作符、数组类型等
|
||||
|
||||
## 🔄 切换数据库
|
||||
|
||||
只需修改 `.env` 文件中的 `DATABASE_URL`,然后使用对应的配置文件执行迁移:
|
||||
|
||||
```bash
|
||||
# 切换到 SQLite
|
||||
DATABASE_URL=sqlite+aiosqlite:///./data/mumuai.db
|
||||
alembic -c alembic-sqlite.ini upgrade head
|
||||
|
||||
# 切换到 PostgreSQL
|
||||
DATABASE_URL=postgresql+asyncpg://user:pass@localhost:5432/db
|
||||
alembic -c alembic-postgres.ini upgrade head
|
||||
```
|
||||
|
||||
## 💡 最佳实践
|
||||
|
||||
1. **开发环境**:使用 SQLite(简单、无需额外服务)
|
||||
2. **生产环境**:使用 PostgreSQL(性能、并发、稳定性)
|
||||
3. **保持同步**:两个数据库的模型定义必须一致
|
||||
4. **测试迁移**:在两种数据库上都测试迁移脚本
|
||||
|
||||
## 🐛 常见问题
|
||||
|
||||
### Q: 迁移脚本生成后可以通用吗?
|
||||
A: 不行。PostgreSQL 和 SQLite 的迁移脚本是独立的,因为:
|
||||
- SQL 语法差异(如 DEFAULT 值)
|
||||
- 类型差异(如 JSON、BOOLEAN)
|
||||
- ALTER TABLE 能力差异
|
||||
|
||||
### Q: 如何从 PostgreSQL 迁移数据到 SQLite?
|
||||
A: 需要编写数据导出/导入脚本,不能直接复用迁移脚本。
|
||||
|
||||
### Q: 为什么 SQLite 迁移这么慢?
|
||||
A: SQLite 的 ALTER TABLE 限制导致需要重建表,这在大表时会很慢。
|
||||
@@ -0,0 +1,2 @@
|
||||
# 此文件确保 versions 目录被 Git 追踪
|
||||
# 迁移版本文件将存放在此目录
|
||||
@@ -0,0 +1,101 @@
|
||||
"""Alembic 环境配置文件 - PostgreSQL"""
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||
|
||||
from alembic import context
|
||||
|
||||
# 添加项目根目录到 Python 路径
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
# 导入应用配置
|
||||
from app.config import settings
|
||||
|
||||
# 导入 Base 和所有模型
|
||||
from app.database import Base
|
||||
from app.models import (
|
||||
Project, Outline, Character, Chapter, GenerationHistory,
|
||||
Settings, WritingStyle, ProjectDefaultStyle,
|
||||
RelationshipType, CharacterRelationship, Organization, OrganizationMember,
|
||||
StoryMemory, PlotAnalysis, AnalysisTask, BatchGenerationTask,
|
||||
RegenerationTask, Career, CharacterCareer, User, MCPPlugin, PromptTemplate
|
||||
)
|
||||
|
||||
# Alembic Config 对象
|
||||
config = context.config
|
||||
|
||||
# 设置数据库连接字符串(从环境变量读取)
|
||||
config.set_main_option("sqlalchemy.url", settings.database_url)
|
||||
|
||||
# 配置日志
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# 设置 target_metadata 为应用的 Base.metadata
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""在'离线'模式下运行迁移"""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
compare_type=True,
|
||||
compare_server_default=True,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
"""执行迁移的核心函数 - PostgreSQL 专用"""
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
compare_type=True,
|
||||
compare_server_default=True,
|
||||
render_as_batch=False, # PostgreSQL 不需要批处理模式
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
async def run_async_migrations() -> None:
|
||||
"""在'在线'模式下运行异步迁移"""
|
||||
configuration = config.get_section(config.config_ini_section, {})
|
||||
configuration["sqlalchemy.url"] = settings.database_url
|
||||
|
||||
connectable = async_engine_from_config(
|
||||
configuration,
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
async with connectable.connect() as connection:
|
||||
await connection.run_sync(do_run_migrations)
|
||||
|
||||
await connectable.dispose()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""在'在线'模式下运行迁移"""
|
||||
asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
# 根据上下文选择运行模式
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
@@ -0,0 +1,26 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
@@ -0,0 +1,554 @@
|
||||
"""初始数据库结构
|
||||
|
||||
Revision ID: ee0a189f1532
|
||||
Revises:
|
||||
Create Date: 2025-12-26 10:08:55.432217
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'ee0a189f1532'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""升级数据库结构"""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('batch_generation_tasks',
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('project_id', sa.String(length=36), nullable=False, comment='项目ID'),
|
||||
sa.Column('user_id', sa.String(length=100), nullable=False, comment='用户ID'),
|
||||
sa.Column('start_chapter_number', sa.Integer(), nullable=False, comment='起始章节序号'),
|
||||
sa.Column('chapter_count', sa.Integer(), nullable=False, comment='生成章节数量'),
|
||||
sa.Column('chapter_ids', sa.JSON(), nullable=False, comment='待生成的章节ID列表'),
|
||||
sa.Column('style_id', sa.Integer(), nullable=True, comment='使用的写作风格ID'),
|
||||
sa.Column('target_word_count', sa.Integer(), nullable=True, comment='目标字数'),
|
||||
sa.Column('enable_analysis', sa.Boolean(), nullable=True, comment='是否启用同步分析'),
|
||||
sa.Column('status', sa.String(length=20), nullable=True, comment='任务状态: pending/running/completed/failed/cancelled'),
|
||||
sa.Column('total_chapters', sa.Integer(), nullable=True, comment='总章节数'),
|
||||
sa.Column('completed_chapters', sa.Integer(), nullable=True, comment='已完成章节数'),
|
||||
sa.Column('failed_chapters', sa.JSON(), nullable=True, comment='失败的章节信息列表'),
|
||||
sa.Column('current_chapter_id', sa.String(length=36), nullable=True, comment='当前正在生成的章节ID'),
|
||||
sa.Column('current_chapter_number', sa.Integer(), nullable=True, comment='当前正在生成的章节序号'),
|
||||
sa.Column('current_retry_count', sa.Integer(), nullable=True, comment='当前章节重试次数'),
|
||||
sa.Column('max_retries', sa.Integer(), nullable=True, comment='最大重试次数'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
|
||||
sa.Column('started_at', sa.DateTime(), nullable=True, comment='开始时间'),
|
||||
sa.Column('completed_at', sa.DateTime(), nullable=True, comment='完成时间'),
|
||||
sa.Column('error_message', sa.String(length=500), nullable=True, comment='错误信息'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('mcp_plugins',
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('user_id', sa.String(length=50), nullable=False, comment='用户ID'),
|
||||
sa.Column('plugin_name', sa.String(length=100), nullable=False, comment='插件名称(唯一标识)'),
|
||||
sa.Column('display_name', sa.String(length=200), nullable=False, comment='显示名称'),
|
||||
sa.Column('description', sa.Text(), nullable=True, comment='插件描述'),
|
||||
sa.Column('plugin_type', sa.String(length=50), nullable=True, comment='插件类型:http/stdio'),
|
||||
sa.Column('server_url', sa.String(length=500), nullable=True, comment='服务器URL(HTTP类型)'),
|
||||
sa.Column('command', sa.String(length=500), nullable=True, comment='启动命令(stdio类型)'),
|
||||
sa.Column('args', sa.JSON(), nullable=True, comment='命令参数(stdio类型)'),
|
||||
sa.Column('env', sa.JSON(), nullable=True, comment='环境变量'),
|
||||
sa.Column('headers', sa.JSON(), nullable=True, comment='HTTP请求头'),
|
||||
sa.Column('config', sa.JSON(), nullable=True, comment='插件特定配置(JSON)'),
|
||||
sa.Column('tools', sa.JSON(), nullable=True, comment='提供的工具列表'),
|
||||
sa.Column('enabled', sa.Boolean(), nullable=True, comment='是否启用'),
|
||||
sa.Column('status', sa.String(length=50), nullable=True, comment='状态:active/inactive/error'),
|
||||
sa.Column('last_error', sa.Text(), nullable=True, comment='最后错误信息'),
|
||||
sa.Column('last_test_at', sa.DateTime(), nullable=True, comment='最后测试时间'),
|
||||
sa.Column('category', sa.String(length=100), nullable=True, comment='分类'),
|
||||
sa.Column('sort_order', sa.Integer(), nullable=True, comment='排序顺序'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index('idx_user_enabled', 'mcp_plugins', ['user_id', 'enabled'], unique=False)
|
||||
op.create_index('idx_user_plugin', 'mcp_plugins', ['user_id', 'plugin_name'], unique=True)
|
||||
op.create_index(op.f('ix_mcp_plugins_user_id'), 'mcp_plugins', ['user_id'], unique=False)
|
||||
op.create_table('projects',
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('user_id', sa.String(length=100), nullable=False, comment='用户ID'),
|
||||
sa.Column('title', sa.String(length=200), nullable=False, comment='项目标题'),
|
||||
sa.Column('description', sa.Text(), nullable=True, comment='项目简介'),
|
||||
sa.Column('theme', sa.Text(), nullable=True, comment='主题'),
|
||||
sa.Column('genre', sa.String(length=50), nullable=True, comment='小说类型'),
|
||||
sa.Column('target_words', sa.Integer(), nullable=True, comment='目标字数'),
|
||||
sa.Column('current_words', sa.Integer(), nullable=True, comment='当前字数'),
|
||||
sa.Column('status', sa.String(length=20), nullable=True, comment='创作状态'),
|
||||
sa.Column('wizard_status', sa.String(length=20), nullable=True, comment='向导完成状态: incomplete/completed'),
|
||||
sa.Column('wizard_step', sa.Integer(), nullable=True, comment='向导当前步骤: 0-4'),
|
||||
sa.Column('outline_mode', sa.String(length=20), nullable=False, comment='大纲章节模式: one-to-one(传统模式) 或 one-to-many(细化模式)'),
|
||||
sa.Column('world_time_period', sa.Text(), nullable=True, comment='时间背景'),
|
||||
sa.Column('world_location', sa.Text(), nullable=True, comment='地理位置'),
|
||||
sa.Column('world_atmosphere', sa.Text(), nullable=True, comment='氛围基调'),
|
||||
sa.Column('world_rules', sa.Text(), nullable=True, comment='世界规则'),
|
||||
sa.Column('chapter_count', sa.Integer(), nullable=True, comment='章节数量'),
|
||||
sa.Column('narrative_perspective', sa.String(length=50), nullable=True, comment='叙事视角:first_person/third_person/omniscient'),
|
||||
sa.Column('character_count', sa.Integer(), nullable=True, comment='角色数量'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
|
||||
sa.CheckConstraint("outline_mode IN ('one-to-one', 'one-to-many')", name='check_outline_mode'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_projects_user_id'), 'projects', ['user_id'], unique=False)
|
||||
op.create_table('prompt_templates',
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('user_id', sa.String(length=50), nullable=False, comment='用户ID'),
|
||||
sa.Column('template_key', sa.String(length=100), nullable=False, comment='模板键名'),
|
||||
sa.Column('template_name', sa.String(length=200), nullable=False, comment='模板显示名称'),
|
||||
sa.Column('template_content', sa.Text(), nullable=False, comment='模板内容'),
|
||||
sa.Column('description', sa.Text(), nullable=True, comment='模板描述'),
|
||||
sa.Column('category', sa.String(length=50), nullable=True, comment='模板分类'),
|
||||
sa.Column('parameters', sa.Text(), nullable=True, comment='模板参数定义(JSON)'),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=True, comment='是否启用'),
|
||||
sa.Column('is_system_default', sa.Boolean(), nullable=True, comment='是否为系统默认模板'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index('idx_user_template', 'prompt_templates', ['user_id', 'template_key'], unique=True)
|
||||
op.create_index(op.f('ix_prompt_templates_user_id'), 'prompt_templates', ['user_id'], unique=False)
|
||||
op.create_table('relationship_types',
|
||||
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column('name', sa.String(length=50), nullable=False, comment='关系名称'),
|
||||
sa.Column('category', sa.String(length=20), nullable=False, comment='分类:family/social/hostile/professional'),
|
||||
sa.Column('reverse_name', sa.String(length=50), nullable=True, comment='反向关系名称'),
|
||||
sa.Column('intimacy_range', sa.String(length=20), nullable=True, comment='亲密度范围:high/medium/low'),
|
||||
sa.Column('icon', sa.String(length=50), nullable=True, comment='图标标识'),
|
||||
sa.Column('description', sa.Text(), nullable=True, comment='关系描述'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_relationship_types_id'), 'relationship_types', ['id'], unique=False)
|
||||
op.create_table('settings',
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('user_id', sa.String(length=50), nullable=False, comment='用户ID'),
|
||||
sa.Column('api_provider', sa.String(length=50), nullable=True, comment='API提供商'),
|
||||
sa.Column('api_key', sa.String(length=500), nullable=True, comment='API密钥'),
|
||||
sa.Column('api_base_url', sa.String(length=500), nullable=True, comment='自定义API地址'),
|
||||
sa.Column('llm_model', sa.String(length=100), nullable=True, comment='模型名称'),
|
||||
sa.Column('temperature', sa.Float(), nullable=True, comment='温度参数'),
|
||||
sa.Column('max_tokens', sa.Integer(), nullable=True, comment='最大token数'),
|
||||
sa.Column('preferences', sa.Text(), nullable=True, comment='其他偏好设置(JSON)'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index('idx_user_id', 'settings', ['user_id'], unique=False)
|
||||
op.create_index(op.f('ix_settings_user_id'), 'settings', ['user_id'], unique=True)
|
||||
op.create_table('user_passwords',
|
||||
sa.Column('user_id', sa.String(length=100), nullable=False, comment='用户ID'),
|
||||
sa.Column('username', sa.String(length=100), nullable=False, comment='用户名'),
|
||||
sa.Column('password_hash', sa.String(length=64), nullable=False, comment='密码哈希(SHA256)'),
|
||||
sa.Column('has_custom_password', sa.Boolean(), nullable=True, comment='是否为自定义密码'),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
|
||||
sa.PrimaryKeyConstraint('user_id')
|
||||
)
|
||||
op.create_index(op.f('ix_user_passwords_user_id'), 'user_passwords', ['user_id'], unique=False)
|
||||
op.create_table('users',
|
||||
sa.Column('user_id', sa.String(length=100), nullable=False, comment='用户ID,格式:linuxdo_{id} 或 local_{id}'),
|
||||
sa.Column('username', sa.String(length=100), nullable=False, comment='用户名'),
|
||||
sa.Column('display_name', sa.String(length=200), nullable=False, comment='显示名称'),
|
||||
sa.Column('avatar_url', sa.String(length=500), nullable=True, comment='头像URL'),
|
||||
sa.Column('trust_level', sa.Integer(), nullable=True, comment='信任等级(仅用于显示)'),
|
||||
sa.Column('is_admin', sa.Boolean(), nullable=True, comment='是否为管理员'),
|
||||
sa.Column('linuxdo_id', sa.String(length=100), nullable=False, comment='LinuxDO用户ID或本地用户ID'),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
|
||||
sa.Column('last_login', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True, comment='最后登录时间'),
|
||||
sa.PrimaryKeyConstraint('user_id')
|
||||
)
|
||||
op.create_index(op.f('ix_users_linuxdo_id'), 'users', ['linuxdo_id'], unique=True)
|
||||
op.create_index(op.f('ix_users_user_id'), 'users', ['user_id'], unique=False)
|
||||
op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=False)
|
||||
op.create_table('careers',
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('project_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('name', sa.String(length=100), nullable=False, comment='职业名称'),
|
||||
sa.Column('type', sa.String(length=20), nullable=False, comment='职业类型: main(主职业)/sub(副职业)'),
|
||||
sa.Column('description', sa.Text(), nullable=True, comment='职业描述'),
|
||||
sa.Column('category', sa.String(length=50), nullable=True, comment='职业分类(如:战斗系、生产系、辅助系)'),
|
||||
sa.Column('stages', sa.Text(), nullable=False, comment="职业阶段列表(JSON): [{level:1, name:'', description:''}, ...]"),
|
||||
sa.Column('max_stage', sa.Integer(), nullable=False, comment='最大阶段数'),
|
||||
sa.Column('requirements', sa.Text(), nullable=True, comment='职业要求/限制'),
|
||||
sa.Column('special_abilities', sa.Text(), nullable=True, comment='特殊能力描述'),
|
||||
sa.Column('worldview_rules', sa.Text(), nullable=True, comment='世界观规则关联'),
|
||||
sa.Column('attribute_bonuses', sa.Text(), nullable=True, comment="属性加成(JSON): {strength: '+10%', intelligence: '+5%'}"),
|
||||
sa.Column('source', sa.String(length=20), nullable=True, comment='来源: ai/manual'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
|
||||
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index('idx_project_id', 'careers', ['project_id'], unique=False)
|
||||
op.create_index('idx_type', 'careers', ['type'], unique=False)
|
||||
op.create_table('outlines',
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('project_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('title', sa.String(length=200), nullable=False, comment='大纲标题'),
|
||||
sa.Column('content', sa.Text(), nullable=True, comment='大纲内容'),
|
||||
sa.Column('structure', sa.Text(), nullable=True, comment='结构化大纲数据(JSON)'),
|
||||
sa.Column('order_index', sa.Integer(), nullable=True, comment='排序序号'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
|
||||
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('writing_styles',
|
||||
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column('user_id', sa.String(length=255), nullable=True, comment='所属用户ID(NULL表示全局预设风格)'),
|
||||
sa.Column('name', sa.String(length=100), nullable=False, comment='风格名称'),
|
||||
sa.Column('style_type', sa.String(length=50), nullable=False, comment='风格类型:preset/custom'),
|
||||
sa.Column('preset_id', sa.String(length=50), nullable=True, comment='预设风格ID:natural/classical/modern等'),
|
||||
sa.Column('description', sa.Text(), nullable=True, comment='风格描述'),
|
||||
sa.Column('prompt_content', sa.Text(), nullable=False, comment='风格提示词内容'),
|
||||
sa.Column('order_index', sa.Integer(), nullable=True, comment='排序序号'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.user_id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('chapters',
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('project_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('chapter_number', sa.Integer(), nullable=False, comment='章节序号'),
|
||||
sa.Column('title', sa.String(length=200), nullable=False, comment='章节标题'),
|
||||
sa.Column('content', sa.Text(), nullable=True, comment='章节内容'),
|
||||
sa.Column('summary', sa.Text(), nullable=True, comment='章节摘要'),
|
||||
sa.Column('word_count', sa.Integer(), nullable=True, comment='字数统计'),
|
||||
sa.Column('status', sa.String(length=20), nullable=True, comment='章节状态'),
|
||||
sa.Column('outline_id', sa.String(length=36), nullable=True, comment='关联的大纲ID'),
|
||||
sa.Column('sub_index', sa.Integer(), nullable=True, comment='大纲下的子章节序号'),
|
||||
sa.Column('expansion_plan', sa.Text(), nullable=True, comment='展开规划详情(JSON): 包含key_events, character_focus, emotional_tone等'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
|
||||
sa.ForeignKeyConstraint(['outline_id'], ['outlines.id'], ondelete='SET NULL'),
|
||||
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('characters',
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('project_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('name', sa.String(length=100), nullable=False, comment='角色/组织名称'),
|
||||
sa.Column('age', sa.String(length=50), nullable=True, comment='年龄'),
|
||||
sa.Column('gender', sa.String(length=50), nullable=True, comment='性别'),
|
||||
sa.Column('is_organization', sa.Boolean(), nullable=True, comment='是否为组织'),
|
||||
sa.Column('role_type', sa.String(length=50), nullable=True, comment='角色类型'),
|
||||
sa.Column('personality', sa.Text(), nullable=True, comment='性格特点/组织特性'),
|
||||
sa.Column('background', sa.Text(), nullable=True, comment='背景故事'),
|
||||
sa.Column('appearance', sa.Text(), nullable=True, comment='外貌描述'),
|
||||
sa.Column('relationships', sa.Text(), nullable=True, comment='人物关系(JSON)'),
|
||||
sa.Column('organization_type', sa.String(length=100), nullable=True, comment='组织类型'),
|
||||
sa.Column('organization_purpose', sa.String(length=500), nullable=True, comment='组织目的'),
|
||||
sa.Column('organization_members', sa.Text(), nullable=True, comment='组织成员(JSON)'),
|
||||
sa.Column('main_career_id', sa.String(length=36), nullable=True, comment='主职业ID'),
|
||||
sa.Column('main_career_stage', sa.Integer(), nullable=True, comment='主职业当前阶段'),
|
||||
sa.Column('sub_careers', sa.Text(), nullable=True, comment='副职业列表(JSON): [{"career_id": "xxx", "stage": 3}, ...]'),
|
||||
sa.Column('avatar_url', sa.String(length=500), nullable=True, comment='头像URL'),
|
||||
sa.Column('traits', sa.Text(), nullable=True, comment='特征标签(JSON)'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
|
||||
sa.ForeignKeyConstraint(['main_career_id'], ['careers.id'], ondelete='SET NULL'),
|
||||
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('project_default_styles',
|
||||
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column('project_id', sa.String(length=36), nullable=False, comment='项目ID'),
|
||||
sa.Column('style_id', sa.Integer(), nullable=False, comment='风格ID'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
|
||||
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['style_id'], ['writing_styles.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('project_id', name='uix_project_default_style')
|
||||
)
|
||||
op.create_table('analysis_tasks',
|
||||
sa.Column('id', sa.String(length=36), nullable=False, comment='任务ID'),
|
||||
sa.Column('chapter_id', sa.String(length=36), nullable=False, comment='章节ID'),
|
||||
sa.Column('user_id', sa.String(length=50), nullable=False, comment='用户ID'),
|
||||
sa.Column('project_id', sa.String(length=36), nullable=False, comment='项目ID'),
|
||||
sa.Column('status', sa.String(length=20), nullable=False, comment='任务状态: pending/running/completed/failed'),
|
||||
sa.Column('progress', sa.Integer(), nullable=True, comment='进度 0-100'),
|
||||
sa.Column('error_message', sa.Text(), nullable=True, comment='错误信息'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
|
||||
sa.Column('started_at', sa.DateTime(), nullable=True, comment='开始执行时间'),
|
||||
sa.Column('completed_at', sa.DateTime(), nullable=True, comment='完成时间'),
|
||||
sa.ForeignKeyConstraint(['chapter_id'], ['chapters.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index('idx_chapter_id_created', 'analysis_tasks', ['chapter_id', 'created_at'], unique=False)
|
||||
op.create_index('idx_status', 'analysis_tasks', ['status'], unique=False)
|
||||
op.create_table('character_careers',
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('character_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('career_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('career_type', sa.String(length=20), nullable=False, comment='main(主职业)/sub(副职业)'),
|
||||
sa.Column('current_stage', sa.Integer(), nullable=False, comment='当前阶段(对应职业中的数值)'),
|
||||
sa.Column('stage_progress', sa.Integer(), nullable=True, comment='阶段内进度(0-100)'),
|
||||
sa.Column('started_at', sa.String(length=100), nullable=True, comment='开始修炼时间(小说时间线)'),
|
||||
sa.Column('reached_current_stage_at', sa.String(length=100), nullable=True, comment='到达当前阶段时间'),
|
||||
sa.Column('notes', sa.Text(), nullable=True, comment='备注(如:修炼心得、特殊事件)'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
|
||||
sa.ForeignKeyConstraint(['career_id'], ['careers.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['character_id'], ['characters.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index('idx_career_type', 'character_careers', ['career_type'], unique=False)
|
||||
op.create_index('idx_character_career', 'character_careers', ['character_id', 'career_id'], unique=True)
|
||||
op.create_index('idx_character_id', 'character_careers', ['character_id'], unique=False)
|
||||
op.create_table('character_relationships',
|
||||
sa.Column('id', sa.String(length=36), nullable=False, comment='关系ID'),
|
||||
sa.Column('project_id', sa.String(length=36), nullable=False, comment='项目ID'),
|
||||
sa.Column('character_from_id', sa.String(length=36), nullable=False, comment='角色A的ID'),
|
||||
sa.Column('character_to_id', sa.String(length=36), nullable=False, comment='角色B的ID'),
|
||||
sa.Column('relationship_type_id', sa.Integer(), nullable=True, comment='关系类型ID'),
|
||||
sa.Column('relationship_name', sa.String(length=100), nullable=True, comment='自定义关系名称'),
|
||||
sa.Column('intimacy_level', sa.Integer(), nullable=True, comment='亲密度:-100到100'),
|
||||
sa.Column('status', sa.String(length=20), nullable=True, comment='状态:active/broken/past/complicated'),
|
||||
sa.Column('description', sa.Text(), nullable=True, comment='关系详细描述'),
|
||||
sa.Column('started_at', sa.String(length=100), nullable=True, comment='关系开始时间(故事时间)'),
|
||||
sa.Column('ended_at', sa.String(length=100), nullable=True, comment='关系结束时间(故事时间)'),
|
||||
sa.Column('source', sa.String(length=20), nullable=True, comment='来源:ai/manual/imported'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
|
||||
sa.ForeignKeyConstraint(['character_from_id'], ['characters.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['character_to_id'], ['characters.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['relationship_type_id'], ['relationship_types.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_character_relationships_character_from_id'), 'character_relationships', ['character_from_id'], unique=False)
|
||||
op.create_index(op.f('ix_character_relationships_character_to_id'), 'character_relationships', ['character_to_id'], unique=False)
|
||||
op.create_index(op.f('ix_character_relationships_project_id'), 'character_relationships', ['project_id'], unique=False)
|
||||
op.create_index(op.f('ix_character_relationships_relationship_type_id'), 'character_relationships', ['relationship_type_id'], unique=False)
|
||||
op.create_table('generation_history',
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('project_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('chapter_id', sa.String(length=36), nullable=True),
|
||||
sa.Column('prompt', sa.Text(), nullable=True, comment='使用的提示词'),
|
||||
sa.Column('generated_content', sa.Text(), nullable=True, comment='生成的内容'),
|
||||
sa.Column('model', sa.String(length=50), nullable=True, comment='使用的模型'),
|
||||
sa.Column('tokens_used', sa.Integer(), nullable=True, comment='消耗的token数'),
|
||||
sa.Column('generation_time', sa.Float(), nullable=True, comment='生成耗时(秒)'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
|
||||
sa.ForeignKeyConstraint(['chapter_id'], ['chapters.id'], ondelete='SET NULL'),
|
||||
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('organizations',
|
||||
sa.Column('id', sa.String(length=36), nullable=False, comment='组织ID'),
|
||||
sa.Column('character_id', sa.String(length=36), nullable=False, comment='关联的角色ID'),
|
||||
sa.Column('project_id', sa.String(length=36), nullable=False, comment='项目ID'),
|
||||
sa.Column('parent_org_id', sa.String(length=36), nullable=True, comment='父组织ID'),
|
||||
sa.Column('level', sa.Integer(), nullable=True, comment='组织层级'),
|
||||
sa.Column('power_level', sa.Integer(), nullable=True, comment='势力等级:0-100'),
|
||||
sa.Column('member_count', sa.Integer(), nullable=True, comment='成员数量'),
|
||||
sa.Column('location', sa.Text(), nullable=True, comment='所在地'),
|
||||
sa.Column('motto', sa.String(length=200), nullable=True, comment='宗旨/口号'),
|
||||
sa.Column('color', sa.String(length=100), nullable=True, comment='代表颜色'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
|
||||
sa.ForeignKeyConstraint(['character_id'], ['characters.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['parent_org_id'], ['organizations.id'], ondelete='SET NULL'),
|
||||
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('character_id')
|
||||
)
|
||||
op.create_index(op.f('ix_organizations_project_id'), 'organizations', ['project_id'], unique=False)
|
||||
op.create_table('plot_analysis',
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('project_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('chapter_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('plot_stage', sa.String(length=50), nullable=True, comment='剧情阶段: 开端/发展/高潮/结局/过渡'),
|
||||
sa.Column('conflict_level', sa.Integer(), nullable=True, comment='冲突强度 1-10'),
|
||||
sa.Column('conflict_types', sa.JSON(), nullable=True, comment="冲突类型列表: ['人与人', '人与己', '人与环境']"),
|
||||
sa.Column('emotional_tone', sa.String(length=100), nullable=True, comment='主导情感: 紧张/温馨/悲伤/激昂/平静'),
|
||||
sa.Column('emotional_intensity', sa.Float(), nullable=True, comment='情感强度 0.0-1.0'),
|
||||
sa.Column('emotional_curve', sa.JSON(), nullable=True, comment='情感曲线: {start: 0.3, middle: 0.7, end: 0.5}'),
|
||||
sa.Column('hooks', sa.JSON(), nullable=True, comment='钩子列表 - 吸引读者的元素: [\n {\n "type": "悬念|情感|冲突|认知",\n "content": "具体内容",\n "strength": 8,\n "position": "开头|中段|结尾"\n }\n ]'),
|
||||
sa.Column('hooks_count', sa.Integer(), nullable=True, comment='钩子数量'),
|
||||
sa.Column('hooks_avg_strength', sa.Float(), nullable=True, comment='钩子平均强度'),
|
||||
sa.Column('foreshadows', sa.JSON(), nullable=True, comment='伏笔列表: [\n {\n "content": "伏笔内容",\n "type": "planted|resolved",\n "strength": 7,\n "subtlety": 8,\n "reference_chapter": 3\n }\n ]'),
|
||||
sa.Column('foreshadows_planted', sa.Integer(), nullable=True, comment='本章埋下的伏笔数量'),
|
||||
sa.Column('foreshadows_resolved', sa.Integer(), nullable=True, comment='本章回收的伏笔数量'),
|
||||
sa.Column('plot_points', sa.JSON(), nullable=True, comment='情节点列表: [\n {\n "content": "情节点描述",\n "importance": 0.9,\n "type": "revelation|conflict|resolution|transition",\n "impact": "对故事的影响描述"\n }\n ]'),
|
||||
sa.Column('plot_points_count', sa.Integer(), nullable=True, comment='情节点数量'),
|
||||
sa.Column('character_states', sa.JSON(), nullable=True, comment='角色状态变化: [\n {\n "character_id": "xxx",\n "character_name": "张三",\n "state_before": "犹豫不决",\n "state_after": "坚定信念",\n "psychological_change": "内心描述",\n "key_event": "触发事件",\n "relationship_changes": {"李四": "关系变化"}\n }\n ]'),
|
||||
sa.Column('scenes', sa.JSON(), nullable=True, comment="场景列表: [{location: '地点', atmosphere: '氛围', duration: '时长'}]"),
|
||||
sa.Column('pacing', sa.String(length=50), nullable=True, comment='节奏: slow|moderate|fast|varied'),
|
||||
sa.Column('overall_quality_score', sa.Float(), nullable=True, comment='整体质量评分 0.0-10.0'),
|
||||
sa.Column('pacing_score', sa.Float(), nullable=True, comment='节奏评分 0.0-10.0'),
|
||||
sa.Column('engagement_score', sa.Float(), nullable=True, comment='吸引力评分 0.0-10.0'),
|
||||
sa.Column('coherence_score', sa.Float(), nullable=True, comment='连贯性评分 0.0-10.0'),
|
||||
sa.Column('analysis_report', sa.Text(), nullable=True, comment='完整的文字分析报告'),
|
||||
sa.Column('suggestions', sa.JSON(), nullable=True, comment="改进建议列表: ['建议1', '建议2']"),
|
||||
sa.Column('word_count', sa.Integer(), nullable=True, comment='章节字数'),
|
||||
sa.Column('dialogue_ratio', sa.Float(), nullable=True, comment='对话占比 0.0-1.0'),
|
||||
sa.Column('description_ratio', sa.Float(), nullable=True, comment='描写占比 0.0-1.0'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='分析时间'),
|
||||
sa.ForeignKeyConstraint(['chapter_id'], ['chapters.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_plot_analysis_chapter_id'), 'plot_analysis', ['chapter_id'], unique=True)
|
||||
op.create_index(op.f('ix_plot_analysis_project_id'), 'plot_analysis', ['project_id'], unique=False)
|
||||
op.create_table('regeneration_tasks',
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('chapter_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('analysis_id', sa.String(length=36), nullable=True, comment='关联的分析结果ID'),
|
||||
sa.Column('user_id', sa.String(length=50), nullable=False),
|
||||
sa.Column('project_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('modification_instructions', sa.Text(), nullable=False, comment='综合修改指令'),
|
||||
sa.Column('original_suggestions', sa.JSON(), nullable=True, comment='来自分析的原始建议列表'),
|
||||
sa.Column('selected_suggestion_indices', sa.JSON(), nullable=True, comment='用户选择的建议索引'),
|
||||
sa.Column('custom_instructions', sa.Text(), nullable=True, comment='用户自定义修改意见'),
|
||||
sa.Column('style_id', sa.Integer(), nullable=True, comment='写作风格ID'),
|
||||
sa.Column('target_word_count', sa.Integer(), nullable=True, comment='目标字数'),
|
||||
sa.Column('focus_areas', sa.JSON(), nullable=True, comment='重点优化方向'),
|
||||
sa.Column('preserve_elements', sa.JSON(), nullable=True, comment='需要保留的元素配置'),
|
||||
sa.Column('status', sa.String(length=20), nullable=True, comment='pending/running/completed/failed'),
|
||||
sa.Column('progress', sa.Integer(), nullable=True, comment='进度 0-100'),
|
||||
sa.Column('error_message', sa.Text(), nullable=True),
|
||||
sa.Column('original_content', sa.Text(), nullable=True, comment='原始章节内容快照'),
|
||||
sa.Column('original_word_count', sa.Integer(), nullable=True, comment='原始字数'),
|
||||
sa.Column('regenerated_content', sa.Text(), nullable=True, comment='重新生成的内容'),
|
||||
sa.Column('regenerated_word_count', sa.Integer(), nullable=True, comment='新内容字数'),
|
||||
sa.Column('version_number', sa.Integer(), nullable=True, comment='版本号'),
|
||||
sa.Column('version_note', sa.String(length=500), nullable=True, comment='版本说明'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True),
|
||||
sa.Column('started_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('completed_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['chapter_id'], ['chapters.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_regeneration_tasks_chapter_id'), 'regeneration_tasks', ['chapter_id'], unique=False)
|
||||
op.create_index(op.f('ix_regeneration_tasks_project_id'), 'regeneration_tasks', ['project_id'], unique=False)
|
||||
op.create_index(op.f('ix_regeneration_tasks_user_id'), 'regeneration_tasks', ['user_id'], unique=False)
|
||||
op.create_table('story_memories',
|
||||
sa.Column('id', sa.String(length=100), nullable=False),
|
||||
sa.Column('project_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('chapter_id', sa.String(length=36), nullable=True),
|
||||
sa.Column('memory_type', sa.String(length=50), nullable=False, comment='\n 记忆类型:\n - plot_point: 情节点\n - character_event: 角色事件\n - world_detail: 世界观细节\n - hook: 钩子(悬念/冲突)\n - foreshadow: 伏笔\n - dialogue: 重要对话\n - scene: 场景描写\n '),
|
||||
sa.Column('title', sa.String(length=200), nullable=True, comment='记忆标题/简述'),
|
||||
sa.Column('content', sa.Text(), nullable=False, comment='记忆内容摘要(100-500字)'),
|
||||
sa.Column('full_context', sa.Text(), nullable=True, comment='完整上下文(可选,用于详细记录)'),
|
||||
sa.Column('related_characters', sa.JSON(), nullable=True, comment="涉及角色ID列表: ['char_id_1', 'char_id_2']"),
|
||||
sa.Column('related_locations', sa.JSON(), nullable=True, comment="涉及地点列表: ['地点1', '地点2']"),
|
||||
sa.Column('tags', sa.JSON(), nullable=True, comment="标签列表: ['悬念', '转折', '伏笔', '高潮']"),
|
||||
sa.Column('importance_score', sa.Float(), nullable=True, comment='重要性评分 0.0-1.0'),
|
||||
sa.Column('story_timeline', sa.Integer(), nullable=False, comment='故事时间线位置(章节序号)'),
|
||||
sa.Column('chapter_position', sa.Integer(), nullable=True, comment='章节内位置(字符位置)'),
|
||||
sa.Column('text_length', sa.Integer(), nullable=True, comment='文本长度(字符数)'),
|
||||
sa.Column('is_foreshadow', sa.Integer(), nullable=True, comment='伏笔状态: 0=普通记忆, 1=已埋下伏笔, 2=伏笔已回收'),
|
||||
sa.Column('foreshadow_resolved_at', sa.String(length=100), nullable=True, comment='伏笔回收的章节ID'),
|
||||
sa.Column('foreshadow_strength', sa.Float(), nullable=True, comment='伏笔强度 0.0-1.0'),
|
||||
sa.Column('vector_id', sa.String(length=100), nullable=True, comment='向量数据库中的唯一ID'),
|
||||
sa.Column('embedding_model', sa.String(length=100), nullable=True, comment='使用的embedding模型'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
|
||||
sa.ForeignKeyConstraint(['chapter_id'], ['chapters.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['foreshadow_resolved_at'], ['chapters.id'], ondelete='SET NULL'),
|
||||
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('vector_id')
|
||||
)
|
||||
op.create_index(op.f('ix_story_memories_chapter_id'), 'story_memories', ['chapter_id'], unique=False)
|
||||
op.create_index(op.f('ix_story_memories_memory_type'), 'story_memories', ['memory_type'], unique=False)
|
||||
op.create_index(op.f('ix_story_memories_project_id'), 'story_memories', ['project_id'], unique=False)
|
||||
op.create_index(op.f('ix_story_memories_story_timeline'), 'story_memories', ['story_timeline'], unique=False)
|
||||
op.create_table('organization_members',
|
||||
sa.Column('id', sa.String(length=36), nullable=False, comment='成员关系ID'),
|
||||
sa.Column('organization_id', sa.String(length=36), nullable=False, comment='组织ID'),
|
||||
sa.Column('character_id', sa.String(length=36), nullable=False, comment='角色ID'),
|
||||
sa.Column('position', sa.String(length=100), nullable=False, comment='职位名称'),
|
||||
sa.Column('rank', sa.Integer(), nullable=True, comment='职位等级'),
|
||||
sa.Column('status', sa.String(length=20), nullable=True, comment='状态:active/retired/expelled/deceased'),
|
||||
sa.Column('joined_at', sa.String(length=100), nullable=True, comment='加入时间(故事时间)'),
|
||||
sa.Column('left_at', sa.String(length=100), nullable=True, comment='离开时间(故事时间)'),
|
||||
sa.Column('loyalty', sa.Integer(), nullable=True, comment='忠诚度:0-100'),
|
||||
sa.Column('contribution', sa.Integer(), nullable=True, comment='贡献度:0-100'),
|
||||
sa.Column('source', sa.String(length=20), nullable=True, comment='来源:ai/manual'),
|
||||
sa.Column('notes', sa.Text(), nullable=True, comment='备注'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
|
||||
sa.ForeignKeyConstraint(['character_id'], ['characters.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_organization_members_character_id'), 'organization_members', ['character_id'], unique=False)
|
||||
op.create_index(op.f('ix_organization_members_organization_id'), 'organization_members', ['organization_id'], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""降级数据库结构"""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f('ix_organization_members_organization_id'), table_name='organization_members')
|
||||
op.drop_index(op.f('ix_organization_members_character_id'), table_name='organization_members')
|
||||
op.drop_table('organization_members')
|
||||
op.drop_index(op.f('ix_story_memories_story_timeline'), table_name='story_memories')
|
||||
op.drop_index(op.f('ix_story_memories_project_id'), table_name='story_memories')
|
||||
op.drop_index(op.f('ix_story_memories_memory_type'), table_name='story_memories')
|
||||
op.drop_index(op.f('ix_story_memories_chapter_id'), table_name='story_memories')
|
||||
op.drop_table('story_memories')
|
||||
op.drop_index(op.f('ix_regeneration_tasks_user_id'), table_name='regeneration_tasks')
|
||||
op.drop_index(op.f('ix_regeneration_tasks_project_id'), table_name='regeneration_tasks')
|
||||
op.drop_index(op.f('ix_regeneration_tasks_chapter_id'), table_name='regeneration_tasks')
|
||||
op.drop_table('regeneration_tasks')
|
||||
op.drop_index(op.f('ix_plot_analysis_project_id'), table_name='plot_analysis')
|
||||
op.drop_index(op.f('ix_plot_analysis_chapter_id'), table_name='plot_analysis')
|
||||
op.drop_table('plot_analysis')
|
||||
op.drop_index(op.f('ix_organizations_project_id'), table_name='organizations')
|
||||
op.drop_table('organizations')
|
||||
op.drop_table('generation_history')
|
||||
op.drop_index(op.f('ix_character_relationships_relationship_type_id'), table_name='character_relationships')
|
||||
op.drop_index(op.f('ix_character_relationships_project_id'), table_name='character_relationships')
|
||||
op.drop_index(op.f('ix_character_relationships_character_to_id'), table_name='character_relationships')
|
||||
op.drop_index(op.f('ix_character_relationships_character_from_id'), table_name='character_relationships')
|
||||
op.drop_table('character_relationships')
|
||||
op.drop_index('idx_character_id', table_name='character_careers')
|
||||
op.drop_index('idx_character_career', table_name='character_careers')
|
||||
op.drop_index('idx_career_type', table_name='character_careers')
|
||||
op.drop_table('character_careers')
|
||||
op.drop_index('idx_status', table_name='analysis_tasks')
|
||||
op.drop_index('idx_chapter_id_created', table_name='analysis_tasks')
|
||||
op.drop_table('analysis_tasks')
|
||||
op.drop_table('project_default_styles')
|
||||
op.drop_table('characters')
|
||||
op.drop_table('chapters')
|
||||
op.drop_table('writing_styles')
|
||||
op.drop_table('outlines')
|
||||
op.drop_index('idx_type', table_name='careers')
|
||||
op.drop_index('idx_project_id', table_name='careers')
|
||||
op.drop_table('careers')
|
||||
op.drop_index(op.f('ix_users_username'), table_name='users')
|
||||
op.drop_index(op.f('ix_users_user_id'), table_name='users')
|
||||
op.drop_index(op.f('ix_users_linuxdo_id'), table_name='users')
|
||||
op.drop_table('users')
|
||||
op.drop_index(op.f('ix_user_passwords_user_id'), table_name='user_passwords')
|
||||
op.drop_table('user_passwords')
|
||||
op.drop_index(op.f('ix_settings_user_id'), table_name='settings')
|
||||
op.drop_index('idx_user_id', table_name='settings')
|
||||
op.drop_table('settings')
|
||||
op.drop_index(op.f('ix_relationship_types_id'), table_name='relationship_types')
|
||||
op.drop_table('relationship_types')
|
||||
op.drop_index(op.f('ix_prompt_templates_user_id'), table_name='prompt_templates')
|
||||
op.drop_index('idx_user_template', table_name='prompt_templates')
|
||||
op.drop_table('prompt_templates')
|
||||
op.drop_index(op.f('ix_projects_user_id'), table_name='projects')
|
||||
op.drop_table('projects')
|
||||
op.drop_index(op.f('ix_mcp_plugins_user_id'), table_name='mcp_plugins')
|
||||
op.drop_index('idx_user_plugin', table_name='mcp_plugins')
|
||||
op.drop_index('idx_user_enabled', table_name='mcp_plugins')
|
||||
op.drop_table('mcp_plugins')
|
||||
op.drop_table('batch_generation_tasks')
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,181 @@
|
||||
"""初始化预置数据
|
||||
|
||||
Revision ID: e411428f00c0
|
||||
Revises: ee0a189f1532
|
||||
Create Date: 2025-12-26 11:02:24.080526
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from datetime import datetime
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import table, column, String, Integer, Float, Text, Boolean, DateTime
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'e411428f00c0'
|
||||
down_revision: Union[str, None] = 'ee0a189f1532'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""插入预置数据"""
|
||||
|
||||
# ==================== 1. 插入关系类型数据 ====================
|
||||
relationship_types_table = table(
|
||||
'relationship_types',
|
||||
column('name', String),
|
||||
column('category', String),
|
||||
column('reverse_name', String),
|
||||
column('intimacy_range', String),
|
||||
column('icon', String),
|
||||
column('description', Text),
|
||||
)
|
||||
|
||||
relationship_types_data = [
|
||||
# 家庭关系
|
||||
{"name": "父亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👨", "description": "父子/父女关系"},
|
||||
{"name": "母亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👩", "description": "母子/母女关系"},
|
||||
{"name": "兄弟", "category": "family", "reverse_name": "兄弟", "intimacy_range": "high", "icon": "👬", "description": "兄弟关系"},
|
||||
{"name": "姐妹", "category": "family", "reverse_name": "姐妹", "intimacy_range": "high", "icon": "👭", "description": "姐妹关系"},
|
||||
{"name": "子女", "category": "family", "reverse_name": "父母", "intimacy_range": "high", "icon": "👶", "description": "子女关系"},
|
||||
{"name": "配偶", "category": "family", "reverse_name": "配偶", "intimacy_range": "high", "icon": "💑", "description": "夫妻关系"},
|
||||
{"name": "恋人", "category": "family", "reverse_name": "恋人", "intimacy_range": "high", "icon": "💕", "description": "恋爱关系"},
|
||||
|
||||
# 社交关系
|
||||
{"name": "师父", "category": "social", "reverse_name": "徒弟", "intimacy_range": "high", "icon": "🎓", "description": "师徒关系(师父视角)"},
|
||||
{"name": "徒弟", "category": "social", "reverse_name": "师父", "intimacy_range": "high", "icon": "📚", "description": "师徒关系(徒弟视角)"},
|
||||
{"name": "朋友", "category": "social", "reverse_name": "朋友", "intimacy_range": "medium", "icon": "🤝", "description": "朋友关系"},
|
||||
{"name": "同学", "category": "social", "reverse_name": "同学", "intimacy_range": "medium", "icon": "🎒", "description": "同学关系"},
|
||||
{"name": "邻居", "category": "social", "reverse_name": "邻居", "intimacy_range": "low", "icon": "🏘️", "description": "邻居关系"},
|
||||
{"name": "知己", "category": "social", "reverse_name": "知己", "intimacy_range": "high", "icon": "💙", "description": "知心好友"},
|
||||
|
||||
# 职业关系
|
||||
{"name": "上司", "category": "professional", "reverse_name": "下属", "intimacy_range": "low", "icon": "👔", "description": "上下级关系(上司视角)"},
|
||||
{"name": "下属", "category": "professional", "reverse_name": "上司", "intimacy_range": "low", "icon": "💼", "description": "上下级关系(下属视角)"},
|
||||
{"name": "同事", "category": "professional", "reverse_name": "同事", "intimacy_range": "medium", "icon": "🤵", "description": "同事关系"},
|
||||
{"name": "合作伙伴", "category": "professional", "reverse_name": "合作伙伴", "intimacy_range": "medium", "icon": "🤜🤛", "description": "合作关系"},
|
||||
|
||||
# 敌对关系
|
||||
{"name": "敌人", "category": "hostile", "reverse_name": "敌人", "intimacy_range": "low", "icon": "⚔️", "description": "敌对关系"},
|
||||
{"name": "仇人", "category": "hostile", "reverse_name": "仇人", "intimacy_range": "low", "icon": "💢", "description": "仇恨关系"},
|
||||
{"name": "竞争对手", "category": "hostile", "reverse_name": "竞争对手", "intimacy_range": "low", "icon": "🎯", "description": "竞争关系"},
|
||||
{"name": "宿敌", "category": "hostile", "reverse_name": "宿敌", "intimacy_range": "low", "icon": "⚡", "description": "宿命之敌"},
|
||||
]
|
||||
|
||||
op.bulk_insert(relationship_types_table, relationship_types_data)
|
||||
print(f"✅ 已插入 {len(relationship_types_data)} 条关系类型数据")
|
||||
|
||||
|
||||
# ==================== 2. 插入全局写作风格预设 ====================
|
||||
# 注意:这里需要从 WritingStyleManager 获取预设配置
|
||||
# 为了避免导入应用代码,我们直接硬编码预设风格
|
||||
|
||||
writing_styles_table = table(
|
||||
'writing_styles',
|
||||
column('user_id', String),
|
||||
column('name', String),
|
||||
column('style_type', String),
|
||||
column('preset_id', String),
|
||||
column('description', Text),
|
||||
column('prompt_content', Text),
|
||||
column('order_index', Integer),
|
||||
)
|
||||
|
||||
writing_styles_data = [
|
||||
{
|
||||
"user_id": None, # NULL 表示全局预设
|
||||
"name": "自然流畅",
|
||||
"style_type": "preset",
|
||||
"preset_id": "natural",
|
||||
"description": "自然流畅的叙事风格,适合现代都市、现实题材",
|
||||
"prompt_content": """写作风格要求:
|
||||
1. 语言简洁明快,贴近现代口语
|
||||
2. 多用短句,节奏流畅
|
||||
3. 注重情感细节的自然流露
|
||||
4. 避免过度修饰和复杂句式""",
|
||||
"order_index": 1
|
||||
},
|
||||
{
|
||||
"user_id": None,
|
||||
"name": "古典优雅",
|
||||
"style_type": "preset",
|
||||
"preset_id": "classical",
|
||||
"description": "古典文雅的写作风格,适合古装、仙侠题材",
|
||||
"prompt_content": """写作风格要求:
|
||||
1. 使用文言、半文言或典雅的白话
|
||||
2. 适当运用古典诗词意象
|
||||
3. 注重意境营造和韵味
|
||||
4. 对话和描写保持古典美感""",
|
||||
"order_index": 2
|
||||
},
|
||||
{
|
||||
"user_id": None,
|
||||
"name": "现代简约",
|
||||
"style_type": "preset",
|
||||
"preset_id": "modern",
|
||||
"description": "现代简约风格,适合轻小说、网文快节奏叙事",
|
||||
"prompt_content": """写作风格要求:
|
||||
1. 语言直白简练,信息密度高
|
||||
2. 多用对话推进情节
|
||||
3. 避免冗长描写,突出关键动作
|
||||
4. 节奏明快,适合快速阅读""",
|
||||
"order_index": 3
|
||||
},
|
||||
{
|
||||
"user_id": None,
|
||||
"name": "文艺细腻",
|
||||
"style_type": "preset",
|
||||
"preset_id": "literary",
|
||||
"description": "文艺细腻风格,注重心理描写和氛围营造",
|
||||
"prompt_content": """写作风格要求:
|
||||
1. 注重心理活动和情感细节
|
||||
2. 善用环境描写烘托氛围
|
||||
3. 语言优美,富有文学性
|
||||
4. 适当使用比喻、象征等修辞手法""",
|
||||
"order_index": 4
|
||||
},
|
||||
{
|
||||
"user_id": None,
|
||||
"name": "紧张悬疑",
|
||||
"style_type": "preset",
|
||||
"preset_id": "suspense",
|
||||
"description": "紧张悬疑风格,适合推理、惊悚题材",
|
||||
"prompt_content": """写作风格要求:
|
||||
1. 营造紧张压迫的氛围
|
||||
2. 多用短句加快节奏
|
||||
3. 善于设置悬念和伏笔
|
||||
4. 注重细节描写,为推理埋下线索""",
|
||||
"order_index": 5
|
||||
},
|
||||
{
|
||||
"user_id": None,
|
||||
"name": "幽默诙谐",
|
||||
"style_type": "preset",
|
||||
"preset_id": "humorous",
|
||||
"description": "幽默诙谐风格,适合轻松搞笑题材",
|
||||
"prompt_content": """写作风格要求:
|
||||
1. 语言活泼风趣,善用俏皮话
|
||||
2. 注重对话的喜剧效果
|
||||
3. 适当夸张和反转制造笑点
|
||||
4. 保持轻松愉快的基调""",
|
||||
"order_index": 6
|
||||
},
|
||||
]
|
||||
|
||||
op.bulk_insert(writing_styles_table, writing_styles_data)
|
||||
print(f"✅ 已插入 {len(writing_styles_data)} 条全局写作风格预设")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""删除预置数据"""
|
||||
|
||||
# 删除写作风格预设(只删除全局预设)
|
||||
op.execute("DELETE FROM writing_styles WHERE user_id IS NULL")
|
||||
print("✅ 已删除全局写作风格预设")
|
||||
|
||||
# 删除关系类型
|
||||
op.execute("DELETE FROM relationship_types")
|
||||
print("✅ 已删除关系类型数据")
|
||||
+30
@@ -0,0 +1,30 @@
|
||||
"""添加system_prompt字段到settings表
|
||||
|
||||
Revision ID: a7e4408e1d5b
|
||||
Revises: e411428f00c0
|
||||
Create Date: 2025-12-27 15:41:22.310160
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'a7e4408e1d5b'
|
||||
down_revision: Union[str, None] = 'e411428f00c0'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('settings', sa.Column('system_prompt', sa.Text(), nullable=True, comment='系统级别提示词,每次AI调用都会使用'))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('settings', 'system_prompt')
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,2 @@
|
||||
# 此文件确保 versions 目录被 Git 追踪
|
||||
# 迁移版本文件将存放在此目录
|
||||
@@ -0,0 +1,102 @@
|
||||
"""Alembic 环境配置文件 - SQLite"""
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||
|
||||
from alembic import context
|
||||
|
||||
# 添加项目根目录到 Python 路径
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
# 导入应用配置
|
||||
from app.config import settings
|
||||
|
||||
# 导入 Base 和所有模型
|
||||
from app.database import Base
|
||||
from app.models import (
|
||||
Project, Outline, Character, Chapter, GenerationHistory,
|
||||
Settings, WritingStyle, ProjectDefaultStyle,
|
||||
RelationshipType, CharacterRelationship, Organization, OrganizationMember,
|
||||
StoryMemory, PlotAnalysis, AnalysisTask, BatchGenerationTask,
|
||||
RegenerationTask, Career, CharacterCareer, User, MCPPlugin, PromptTemplate
|
||||
)
|
||||
|
||||
# Alembic Config 对象
|
||||
config = context.config
|
||||
|
||||
# 设置数据库连接字符串(从环境变量读取)
|
||||
config.set_main_option("sqlalchemy.url", settings.database_url)
|
||||
|
||||
# 配置日志
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# 设置 target_metadata 为应用的 Base.metadata
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""在'离线'模式下运行迁移"""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
compare_type=True,
|
||||
compare_server_default=True,
|
||||
render_as_batch=True, # SQLite 必须启用批处理模式
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
"""执行迁移的核心函数 - SQLite 专用"""
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
compare_type=True,
|
||||
compare_server_default=True,
|
||||
render_as_batch=True, # SQLite 必须启用批处理模式
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
async def run_async_migrations() -> None:
|
||||
"""在'在线'模式下运行异步迁移"""
|
||||
configuration = config.get_section(config.config_ini_section, {})
|
||||
configuration["sqlalchemy.url"] = settings.database_url
|
||||
|
||||
connectable = async_engine_from_config(
|
||||
configuration,
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
async with connectable.connect() as connection:
|
||||
await connection.run_sync(do_run_migrations)
|
||||
|
||||
await connectable.dispose()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""在'在线'模式下运行迁移"""
|
||||
asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
# 根据上下文选择运行模式
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
@@ -0,0 +1,26 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
@@ -0,0 +1,616 @@
|
||||
"""初始化SQLite数据库
|
||||
|
||||
Revision ID: fbeb1038c728
|
||||
Revises:
|
||||
Create Date: 2025-12-26 13:22:53.151546
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'fbeb1038c728'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('batch_generation_tasks',
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('project_id', sa.String(length=36), nullable=False, comment='项目ID'),
|
||||
sa.Column('user_id', sa.String(length=100), nullable=False, comment='用户ID'),
|
||||
sa.Column('start_chapter_number', sa.Integer(), nullable=False, comment='起始章节序号'),
|
||||
sa.Column('chapter_count', sa.Integer(), nullable=False, comment='生成章节数量'),
|
||||
sa.Column('chapter_ids', sa.JSON(), nullable=False, comment='待生成的章节ID列表'),
|
||||
sa.Column('style_id', sa.Integer(), nullable=True, comment='使用的写作风格ID'),
|
||||
sa.Column('target_word_count', sa.Integer(), nullable=True, comment='目标字数'),
|
||||
sa.Column('enable_analysis', sa.Boolean(), nullable=True, comment='是否启用同步分析'),
|
||||
sa.Column('status', sa.String(length=20), nullable=True, comment='任务状态: pending/running/completed/failed/cancelled'),
|
||||
sa.Column('total_chapters', sa.Integer(), nullable=True, comment='总章节数'),
|
||||
sa.Column('completed_chapters', sa.Integer(), nullable=True, comment='已完成章节数'),
|
||||
sa.Column('failed_chapters', sa.JSON(), nullable=True, comment='失败的章节信息列表'),
|
||||
sa.Column('current_chapter_id', sa.String(length=36), nullable=True, comment='当前正在生成的章节ID'),
|
||||
sa.Column('current_chapter_number', sa.Integer(), nullable=True, comment='当前正在生成的章节序号'),
|
||||
sa.Column('current_retry_count', sa.Integer(), nullable=True, comment='当前章节重试次数'),
|
||||
sa.Column('max_retries', sa.Integer(), nullable=True, comment='最大重试次数'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
|
||||
sa.Column('started_at', sa.DateTime(), nullable=True, comment='开始时间'),
|
||||
sa.Column('completed_at', sa.DateTime(), nullable=True, comment='完成时间'),
|
||||
sa.Column('error_message', sa.String(length=500), nullable=True, comment='错误信息'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('mcp_plugins',
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('user_id', sa.String(length=50), nullable=False, comment='用户ID'),
|
||||
sa.Column('plugin_name', sa.String(length=100), nullable=False, comment='插件名称(唯一标识)'),
|
||||
sa.Column('display_name', sa.String(length=200), nullable=False, comment='显示名称'),
|
||||
sa.Column('description', sa.Text(), nullable=True, comment='插件描述'),
|
||||
sa.Column('plugin_type', sa.String(length=50), nullable=True, comment='插件类型:http/stdio'),
|
||||
sa.Column('server_url', sa.String(length=500), nullable=True, comment='服务器URL(HTTP类型)'),
|
||||
sa.Column('command', sa.String(length=500), nullable=True, comment='启动命令(stdio类型)'),
|
||||
sa.Column('args', sa.JSON(), nullable=True, comment='命令参数(stdio类型)'),
|
||||
sa.Column('env', sa.JSON(), nullable=True, comment='环境变量'),
|
||||
sa.Column('headers', sa.JSON(), nullable=True, comment='HTTP请求头'),
|
||||
sa.Column('config', sa.JSON(), nullable=True, comment='插件特定配置(JSON)'),
|
||||
sa.Column('tools', sa.JSON(), nullable=True, comment='提供的工具列表'),
|
||||
sa.Column('enabled', sa.Boolean(), nullable=True, comment='是否启用'),
|
||||
sa.Column('status', sa.String(length=50), nullable=True, comment='状态:active/inactive/error'),
|
||||
sa.Column('last_error', sa.Text(), nullable=True, comment='最后错误信息'),
|
||||
sa.Column('last_test_at', sa.DateTime(), nullable=True, comment='最后测试时间'),
|
||||
sa.Column('category', sa.String(length=100), nullable=True, comment='分类'),
|
||||
sa.Column('sort_order', sa.Integer(), nullable=True, comment='排序顺序'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
with op.batch_alter_table('mcp_plugins', schema=None) as batch_op:
|
||||
batch_op.create_index('idx_user_enabled', ['user_id', 'enabled'], unique=False)
|
||||
batch_op.create_index('idx_user_plugin', ['user_id', 'plugin_name'], unique=True)
|
||||
batch_op.create_index(batch_op.f('ix_mcp_plugins_user_id'), ['user_id'], unique=False)
|
||||
|
||||
op.create_table('projects',
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('user_id', sa.String(length=100), nullable=False, comment='用户ID'),
|
||||
sa.Column('title', sa.String(length=200), nullable=False, comment='项目标题'),
|
||||
sa.Column('description', sa.Text(), nullable=True, comment='项目简介'),
|
||||
sa.Column('theme', sa.Text(), nullable=True, comment='主题'),
|
||||
sa.Column('genre', sa.String(length=50), nullable=True, comment='小说类型'),
|
||||
sa.Column('target_words', sa.Integer(), nullable=True, comment='目标字数'),
|
||||
sa.Column('current_words', sa.Integer(), nullable=True, comment='当前字数'),
|
||||
sa.Column('status', sa.String(length=20), nullable=True, comment='创作状态'),
|
||||
sa.Column('wizard_status', sa.String(length=20), nullable=True, comment='向导完成状态: incomplete/completed'),
|
||||
sa.Column('wizard_step', sa.Integer(), nullable=True, comment='向导当前步骤: 0-4'),
|
||||
sa.Column('outline_mode', sa.String(length=20), nullable=False, comment='大纲章节模式: one-to-one(传统模式) 或 one-to-many(细化模式)'),
|
||||
sa.Column('world_time_period', sa.Text(), nullable=True, comment='时间背景'),
|
||||
sa.Column('world_location', sa.Text(), nullable=True, comment='地理位置'),
|
||||
sa.Column('world_atmosphere', sa.Text(), nullable=True, comment='氛围基调'),
|
||||
sa.Column('world_rules', sa.Text(), nullable=True, comment='世界规则'),
|
||||
sa.Column('chapter_count', sa.Integer(), nullable=True, comment='章节数量'),
|
||||
sa.Column('narrative_perspective', sa.String(length=50), nullable=True, comment='叙事视角:first_person/third_person/omniscient'),
|
||||
sa.Column('character_count', sa.Integer(), nullable=True, comment='角色数量'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
|
||||
sa.CheckConstraint("outline_mode IN ('one-to-one', 'one-to-many')", name='check_outline_mode'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
with op.batch_alter_table('projects', schema=None) as batch_op:
|
||||
batch_op.create_index(batch_op.f('ix_projects_user_id'), ['user_id'], unique=False)
|
||||
|
||||
op.create_table('prompt_templates',
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('user_id', sa.String(length=50), nullable=False, comment='用户ID'),
|
||||
sa.Column('template_key', sa.String(length=100), nullable=False, comment='模板键名'),
|
||||
sa.Column('template_name', sa.String(length=200), nullable=False, comment='模板显示名称'),
|
||||
sa.Column('template_content', sa.Text(), nullable=False, comment='模板内容'),
|
||||
sa.Column('description', sa.Text(), nullable=True, comment='模板描述'),
|
||||
sa.Column('category', sa.String(length=50), nullable=True, comment='模板分类'),
|
||||
sa.Column('parameters', sa.Text(), nullable=True, comment='模板参数定义(JSON)'),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=True, comment='是否启用'),
|
||||
sa.Column('is_system_default', sa.Boolean(), nullable=True, comment='是否为系统默认模板'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
with op.batch_alter_table('prompt_templates', schema=None) as batch_op:
|
||||
batch_op.create_index('idx_user_template', ['user_id', 'template_key'], unique=True)
|
||||
batch_op.create_index(batch_op.f('ix_prompt_templates_user_id'), ['user_id'], unique=False)
|
||||
|
||||
op.create_table('relationship_types',
|
||||
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column('name', sa.String(length=50), nullable=False, comment='关系名称'),
|
||||
sa.Column('category', sa.String(length=20), nullable=False, comment='分类:family/social/hostile/professional'),
|
||||
sa.Column('reverse_name', sa.String(length=50), nullable=True, comment='反向关系名称'),
|
||||
sa.Column('intimacy_range', sa.String(length=20), nullable=True, comment='亲密度范围:high/medium/low'),
|
||||
sa.Column('icon', sa.String(length=50), nullable=True, comment='图标标识'),
|
||||
sa.Column('description', sa.Text(), nullable=True, comment='关系描述'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
with op.batch_alter_table('relationship_types', schema=None) as batch_op:
|
||||
batch_op.create_index(batch_op.f('ix_relationship_types_id'), ['id'], unique=False)
|
||||
|
||||
op.create_table('settings',
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('user_id', sa.String(length=50), nullable=False, comment='用户ID'),
|
||||
sa.Column('api_provider', sa.String(length=50), nullable=True, comment='API提供商'),
|
||||
sa.Column('api_key', sa.String(length=500), nullable=True, comment='API密钥'),
|
||||
sa.Column('api_base_url', sa.String(length=500), nullable=True, comment='自定义API地址'),
|
||||
sa.Column('llm_model', sa.String(length=100), nullable=True, comment='模型名称'),
|
||||
sa.Column('temperature', sa.Float(), nullable=True, comment='温度参数'),
|
||||
sa.Column('max_tokens', sa.Integer(), nullable=True, comment='最大token数'),
|
||||
sa.Column('preferences', sa.Text(), nullable=True, comment='其他偏好设置(JSON)'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
with op.batch_alter_table('settings', schema=None) as batch_op:
|
||||
batch_op.create_index('idx_user_id', ['user_id'], unique=False)
|
||||
batch_op.create_index(batch_op.f('ix_settings_user_id'), ['user_id'], unique=True)
|
||||
|
||||
op.create_table('user_passwords',
|
||||
sa.Column('user_id', sa.String(length=100), nullable=False, comment='用户ID'),
|
||||
sa.Column('username', sa.String(length=100), nullable=False, comment='用户名'),
|
||||
sa.Column('password_hash', sa.String(length=64), nullable=False, comment='密码哈希(SHA256)'),
|
||||
sa.Column('has_custom_password', sa.Boolean(), nullable=True, comment='是否为自定义密码'),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
|
||||
sa.PrimaryKeyConstraint('user_id')
|
||||
)
|
||||
with op.batch_alter_table('user_passwords', schema=None) as batch_op:
|
||||
batch_op.create_index(batch_op.f('ix_user_passwords_user_id'), ['user_id'], unique=False)
|
||||
|
||||
op.create_table('users',
|
||||
sa.Column('user_id', sa.String(length=100), nullable=False, comment='用户ID,格式:linuxdo_{id} 或 local_{id}'),
|
||||
sa.Column('username', sa.String(length=100), nullable=False, comment='用户名'),
|
||||
sa.Column('display_name', sa.String(length=200), nullable=False, comment='显示名称'),
|
||||
sa.Column('avatar_url', sa.String(length=500), nullable=True, comment='头像URL'),
|
||||
sa.Column('trust_level', sa.Integer(), nullable=True, comment='信任等级(仅用于显示)'),
|
||||
sa.Column('is_admin', sa.Boolean(), nullable=True, comment='是否为管理员'),
|
||||
sa.Column('linuxdo_id', sa.String(length=100), nullable=False, comment='LinuxDO用户ID或本地用户ID'),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
|
||||
sa.Column('last_login', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='最后登录时间'),
|
||||
sa.PrimaryKeyConstraint('user_id')
|
||||
)
|
||||
with op.batch_alter_table('users', schema=None) as batch_op:
|
||||
batch_op.create_index(batch_op.f('ix_users_linuxdo_id'), ['linuxdo_id'], unique=True)
|
||||
batch_op.create_index(batch_op.f('ix_users_user_id'), ['user_id'], unique=False)
|
||||
batch_op.create_index(batch_op.f('ix_users_username'), ['username'], unique=False)
|
||||
|
||||
op.create_table('careers',
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('project_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('name', sa.String(length=100), nullable=False, comment='职业名称'),
|
||||
sa.Column('type', sa.String(length=20), nullable=False, comment='职业类型: main(主职业)/sub(副职业)'),
|
||||
sa.Column('description', sa.Text(), nullable=True, comment='职业描述'),
|
||||
sa.Column('category', sa.String(length=50), nullable=True, comment='职业分类(如:战斗系、生产系、辅助系)'),
|
||||
sa.Column('stages', sa.Text(), nullable=False, comment="职业阶段列表(JSON): [{level:1, name:'', description:''}, ...]"),
|
||||
sa.Column('max_stage', sa.Integer(), nullable=False, comment='最大阶段数'),
|
||||
sa.Column('requirements', sa.Text(), nullable=True, comment='职业要求/限制'),
|
||||
sa.Column('special_abilities', sa.Text(), nullable=True, comment='特殊能力描述'),
|
||||
sa.Column('worldview_rules', sa.Text(), nullable=True, comment='世界观规则关联'),
|
||||
sa.Column('attribute_bonuses', sa.Text(), nullable=True, comment="属性加成(JSON): {strength: '+10%', intelligence: '+5%'}"),
|
||||
sa.Column('source', sa.String(length=20), nullable=True, comment='来源: ai/manual'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
|
||||
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
with op.batch_alter_table('careers', schema=None) as batch_op:
|
||||
batch_op.create_index('idx_project_id', ['project_id'], unique=False)
|
||||
batch_op.create_index('idx_type', ['type'], unique=False)
|
||||
|
||||
op.create_table('outlines',
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('project_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('title', sa.String(length=200), nullable=False, comment='大纲标题'),
|
||||
sa.Column('content', sa.Text(), nullable=True, comment='大纲内容'),
|
||||
sa.Column('structure', sa.Text(), nullable=True, comment='结构化大纲数据(JSON)'),
|
||||
sa.Column('order_index', sa.Integer(), nullable=True, comment='排序序号'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
|
||||
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('writing_styles',
|
||||
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column('user_id', sa.String(length=255), nullable=True, comment='所属用户ID(NULL表示全局预设风格)'),
|
||||
sa.Column('name', sa.String(length=100), nullable=False, comment='风格名称'),
|
||||
sa.Column('style_type', sa.String(length=50), nullable=False, comment='风格类型:preset/custom'),
|
||||
sa.Column('preset_id', sa.String(length=50), nullable=True, comment='预设风格ID:natural/classical/modern等'),
|
||||
sa.Column('description', sa.Text(), nullable=True, comment='风格描述'),
|
||||
sa.Column('prompt_content', sa.Text(), nullable=False, comment='风格提示词内容'),
|
||||
sa.Column('order_index', sa.Integer(), nullable=True, comment='排序序号'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.user_id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('chapters',
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('project_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('chapter_number', sa.Integer(), nullable=False, comment='章节序号'),
|
||||
sa.Column('title', sa.String(length=200), nullable=False, comment='章节标题'),
|
||||
sa.Column('content', sa.Text(), nullable=True, comment='章节内容'),
|
||||
sa.Column('summary', sa.Text(), nullable=True, comment='章节摘要'),
|
||||
sa.Column('word_count', sa.Integer(), nullable=True, comment='字数统计'),
|
||||
sa.Column('status', sa.String(length=20), nullable=True, comment='章节状态'),
|
||||
sa.Column('outline_id', sa.String(length=36), nullable=True, comment='关联的大纲ID'),
|
||||
sa.Column('sub_index', sa.Integer(), nullable=True, comment='大纲下的子章节序号'),
|
||||
sa.Column('expansion_plan', sa.Text(), nullable=True, comment='展开规划详情(JSON): 包含key_events, character_focus, emotional_tone等'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
|
||||
sa.ForeignKeyConstraint(['outline_id'], ['outlines.id'], ondelete='SET NULL'),
|
||||
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('characters',
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('project_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('name', sa.String(length=100), nullable=False, comment='角色/组织名称'),
|
||||
sa.Column('age', sa.String(length=50), nullable=True, comment='年龄'),
|
||||
sa.Column('gender', sa.String(length=50), nullable=True, comment='性别'),
|
||||
sa.Column('is_organization', sa.Boolean(), nullable=True, comment='是否为组织'),
|
||||
sa.Column('role_type', sa.String(length=50), nullable=True, comment='角色类型'),
|
||||
sa.Column('personality', sa.Text(), nullable=True, comment='性格特点/组织特性'),
|
||||
sa.Column('background', sa.Text(), nullable=True, comment='背景故事'),
|
||||
sa.Column('appearance', sa.Text(), nullable=True, comment='外貌描述'),
|
||||
sa.Column('relationships', sa.Text(), nullable=True, comment='人物关系(JSON)'),
|
||||
sa.Column('organization_type', sa.String(length=100), nullable=True, comment='组织类型'),
|
||||
sa.Column('organization_purpose', sa.String(length=500), nullable=True, comment='组织目的'),
|
||||
sa.Column('organization_members', sa.Text(), nullable=True, comment='组织成员(JSON)'),
|
||||
sa.Column('main_career_id', sa.String(length=36), nullable=True, comment='主职业ID'),
|
||||
sa.Column('main_career_stage', sa.Integer(), nullable=True, comment='主职业当前阶段'),
|
||||
sa.Column('sub_careers', sa.Text(), nullable=True, comment='副职业列表(JSON): [{"career_id": "xxx", "stage": 3}, ...]'),
|
||||
sa.Column('avatar_url', sa.String(length=500), nullable=True, comment='头像URL'),
|
||||
sa.Column('traits', sa.Text(), nullable=True, comment='特征标签(JSON)'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
|
||||
sa.ForeignKeyConstraint(['main_career_id'], ['careers.id'], ondelete='SET NULL'),
|
||||
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('project_default_styles',
|
||||
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column('project_id', sa.String(length=36), nullable=False, comment='项目ID'),
|
||||
sa.Column('style_id', sa.Integer(), nullable=False, comment='风格ID'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
|
||||
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['style_id'], ['writing_styles.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('project_id', name='uix_project_default_style')
|
||||
)
|
||||
op.create_table('analysis_tasks',
|
||||
sa.Column('id', sa.String(length=36), nullable=False, comment='任务ID'),
|
||||
sa.Column('chapter_id', sa.String(length=36), nullable=False, comment='章节ID'),
|
||||
sa.Column('user_id', sa.String(length=50), nullable=False, comment='用户ID'),
|
||||
sa.Column('project_id', sa.String(length=36), nullable=False, comment='项目ID'),
|
||||
sa.Column('status', sa.String(length=20), nullable=False, comment='任务状态: pending/running/completed/failed'),
|
||||
sa.Column('progress', sa.Integer(), nullable=True, comment='进度 0-100'),
|
||||
sa.Column('error_message', sa.Text(), nullable=True, comment='错误信息'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
|
||||
sa.Column('started_at', sa.DateTime(), nullable=True, comment='开始执行时间'),
|
||||
sa.Column('completed_at', sa.DateTime(), nullable=True, comment='完成时间'),
|
||||
sa.ForeignKeyConstraint(['chapter_id'], ['chapters.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
with op.batch_alter_table('analysis_tasks', schema=None) as batch_op:
|
||||
batch_op.create_index('idx_chapter_id_created', ['chapter_id', 'created_at'], unique=False)
|
||||
batch_op.create_index('idx_status', ['status'], unique=False)
|
||||
|
||||
op.create_table('character_careers',
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('character_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('career_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('career_type', sa.String(length=20), nullable=False, comment='main(主职业)/sub(副职业)'),
|
||||
sa.Column('current_stage', sa.Integer(), nullable=False, comment='当前阶段(对应职业中的数值)'),
|
||||
sa.Column('stage_progress', sa.Integer(), nullable=True, comment='阶段内进度(0-100)'),
|
||||
sa.Column('started_at', sa.String(length=100), nullable=True, comment='开始修炼时间(小说时间线)'),
|
||||
sa.Column('reached_current_stage_at', sa.String(length=100), nullable=True, comment='到达当前阶段时间'),
|
||||
sa.Column('notes', sa.Text(), nullable=True, comment='备注(如:修炼心得、特殊事件)'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
|
||||
sa.ForeignKeyConstraint(['career_id'], ['careers.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['character_id'], ['characters.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
with op.batch_alter_table('character_careers', schema=None) as batch_op:
|
||||
batch_op.create_index('idx_career_type', ['career_type'], unique=False)
|
||||
batch_op.create_index('idx_character_career', ['character_id', 'career_id'], unique=True)
|
||||
batch_op.create_index('idx_character_id', ['character_id'], unique=False)
|
||||
|
||||
op.create_table('character_relationships',
|
||||
sa.Column('id', sa.String(length=36), nullable=False, comment='关系ID'),
|
||||
sa.Column('project_id', sa.String(length=36), nullable=False, comment='项目ID'),
|
||||
sa.Column('character_from_id', sa.String(length=36), nullable=False, comment='角色A的ID'),
|
||||
sa.Column('character_to_id', sa.String(length=36), nullable=False, comment='角色B的ID'),
|
||||
sa.Column('relationship_type_id', sa.Integer(), nullable=True, comment='关系类型ID'),
|
||||
sa.Column('relationship_name', sa.String(length=100), nullable=True, comment='自定义关系名称'),
|
||||
sa.Column('intimacy_level', sa.Integer(), nullable=True, comment='亲密度:-100到100'),
|
||||
sa.Column('status', sa.String(length=20), nullable=True, comment='状态:active/broken/past/complicated'),
|
||||
sa.Column('description', sa.Text(), nullable=True, comment='关系详细描述'),
|
||||
sa.Column('started_at', sa.String(length=100), nullable=True, comment='关系开始时间(故事时间)'),
|
||||
sa.Column('ended_at', sa.String(length=100), nullable=True, comment='关系结束时间(故事时间)'),
|
||||
sa.Column('source', sa.String(length=20), nullable=True, comment='来源:ai/manual/imported'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
|
||||
sa.ForeignKeyConstraint(['character_from_id'], ['characters.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['character_to_id'], ['characters.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['relationship_type_id'], ['relationship_types.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
with op.batch_alter_table('character_relationships', schema=None) as batch_op:
|
||||
batch_op.create_index(batch_op.f('ix_character_relationships_character_from_id'), ['character_from_id'], unique=False)
|
||||
batch_op.create_index(batch_op.f('ix_character_relationships_character_to_id'), ['character_to_id'], unique=False)
|
||||
batch_op.create_index(batch_op.f('ix_character_relationships_project_id'), ['project_id'], unique=False)
|
||||
batch_op.create_index(batch_op.f('ix_character_relationships_relationship_type_id'), ['relationship_type_id'], unique=False)
|
||||
|
||||
op.create_table('generation_history',
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('project_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('chapter_id', sa.String(length=36), nullable=True),
|
||||
sa.Column('prompt', sa.Text(), nullable=True, comment='使用的提示词'),
|
||||
sa.Column('generated_content', sa.Text(), nullable=True, comment='生成的内容'),
|
||||
sa.Column('model', sa.String(length=50), nullable=True, comment='使用的模型'),
|
||||
sa.Column('tokens_used', sa.Integer(), nullable=True, comment='消耗的token数'),
|
||||
sa.Column('generation_time', sa.Float(), nullable=True, comment='生成耗时(秒)'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
|
||||
sa.ForeignKeyConstraint(['chapter_id'], ['chapters.id'], ondelete='SET NULL'),
|
||||
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('organizations',
|
||||
sa.Column('id', sa.String(length=36), nullable=False, comment='组织ID'),
|
||||
sa.Column('character_id', sa.String(length=36), nullable=False, comment='关联的角色ID'),
|
||||
sa.Column('project_id', sa.String(length=36), nullable=False, comment='项目ID'),
|
||||
sa.Column('parent_org_id', sa.String(length=36), nullable=True, comment='父组织ID'),
|
||||
sa.Column('level', sa.Integer(), nullable=True, comment='组织层级'),
|
||||
sa.Column('power_level', sa.Integer(), nullable=True, comment='势力等级:0-100'),
|
||||
sa.Column('member_count', sa.Integer(), nullable=True, comment='成员数量'),
|
||||
sa.Column('location', sa.Text(), nullable=True, comment='所在地'),
|
||||
sa.Column('motto', sa.String(length=200), nullable=True, comment='宗旨/口号'),
|
||||
sa.Column('color', sa.String(length=100), nullable=True, comment='代表颜色'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
|
||||
sa.ForeignKeyConstraint(['character_id'], ['characters.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['parent_org_id'], ['organizations.id'], ondelete='SET NULL'),
|
||||
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('character_id')
|
||||
)
|
||||
with op.batch_alter_table('organizations', schema=None) as batch_op:
|
||||
batch_op.create_index(batch_op.f('ix_organizations_project_id'), ['project_id'], unique=False)
|
||||
|
||||
op.create_table('plot_analysis',
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('project_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('chapter_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('plot_stage', sa.String(length=50), nullable=True, comment='剧情阶段: 开端/发展/高潮/结局/过渡'),
|
||||
sa.Column('conflict_level', sa.Integer(), nullable=True, comment='冲突强度 1-10'),
|
||||
sa.Column('conflict_types', sa.JSON(), nullable=True, comment="冲突类型列表: ['人与人', '人与己', '人与环境']"),
|
||||
sa.Column('emotional_tone', sa.String(length=100), nullable=True, comment='主导情感: 紧张/温馨/悲伤/激昂/平静'),
|
||||
sa.Column('emotional_intensity', sa.Float(), nullable=True, comment='情感强度 0.0-1.0'),
|
||||
sa.Column('emotional_curve', sa.JSON(), nullable=True, comment='情感曲线: {start: 0.3, middle: 0.7, end: 0.5}'),
|
||||
sa.Column('hooks', sa.JSON(), nullable=True, comment='钩子列表 - 吸引读者的元素: [\n {\n "type": "悬念|情感|冲突|认知",\n "content": "具体内容",\n "strength": 8,\n "position": "开头|中段|结尾"\n }\n ]'),
|
||||
sa.Column('hooks_count', sa.Integer(), nullable=True, comment='钩子数量'),
|
||||
sa.Column('hooks_avg_strength', sa.Float(), nullable=True, comment='钩子平均强度'),
|
||||
sa.Column('foreshadows', sa.JSON(), nullable=True, comment='伏笔列表: [\n {\n "content": "伏笔内容",\n "type": "planted|resolved",\n "strength": 7,\n "subtlety": 8,\n "reference_chapter": 3\n }\n ]'),
|
||||
sa.Column('foreshadows_planted', sa.Integer(), nullable=True, comment='本章埋下的伏笔数量'),
|
||||
sa.Column('foreshadows_resolved', sa.Integer(), nullable=True, comment='本章回收的伏笔数量'),
|
||||
sa.Column('plot_points', sa.JSON(), nullable=True, comment='情节点列表: [\n {\n "content": "情节点描述",\n "importance": 0.9,\n "type": "revelation|conflict|resolution|transition",\n "impact": "对故事的影响描述"\n }\n ]'),
|
||||
sa.Column('plot_points_count', sa.Integer(), nullable=True, comment='情节点数量'),
|
||||
sa.Column('character_states', sa.JSON(), nullable=True, comment='角色状态变化: [\n {\n "character_id": "xxx",\n "character_name": "张三",\n "state_before": "犹豫不决",\n "state_after": "坚定信念",\n "psychological_change": "内心描述",\n "key_event": "触发事件",\n "relationship_changes": {"李四": "关系变化"}\n }\n ]'),
|
||||
sa.Column('scenes', sa.JSON(), nullable=True, comment="场景列表: [{location: '地点', atmosphere: '氛围', duration: '时长'}]"),
|
||||
sa.Column('pacing', sa.String(length=50), nullable=True, comment='节奏: slow|moderate|fast|varied'),
|
||||
sa.Column('overall_quality_score', sa.Float(), nullable=True, comment='整体质量评分 0.0-10.0'),
|
||||
sa.Column('pacing_score', sa.Float(), nullable=True, comment='节奏评分 0.0-10.0'),
|
||||
sa.Column('engagement_score', sa.Float(), nullable=True, comment='吸引力评分 0.0-10.0'),
|
||||
sa.Column('coherence_score', sa.Float(), nullable=True, comment='连贯性评分 0.0-10.0'),
|
||||
sa.Column('analysis_report', sa.Text(), nullable=True, comment='完整的文字分析报告'),
|
||||
sa.Column('suggestions', sa.JSON(), nullable=True, comment="改进建议列表: ['建议1', '建议2']"),
|
||||
sa.Column('word_count', sa.Integer(), nullable=True, comment='章节字数'),
|
||||
sa.Column('dialogue_ratio', sa.Float(), nullable=True, comment='对话占比 0.0-1.0'),
|
||||
sa.Column('description_ratio', sa.Float(), nullable=True, comment='描写占比 0.0-1.0'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='分析时间'),
|
||||
sa.ForeignKeyConstraint(['chapter_id'], ['chapters.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
with op.batch_alter_table('plot_analysis', schema=None) as batch_op:
|
||||
batch_op.create_index(batch_op.f('ix_plot_analysis_chapter_id'), ['chapter_id'], unique=True)
|
||||
batch_op.create_index(batch_op.f('ix_plot_analysis_project_id'), ['project_id'], unique=False)
|
||||
|
||||
op.create_table('regeneration_tasks',
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('chapter_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('analysis_id', sa.String(length=36), nullable=True, comment='关联的分析结果ID'),
|
||||
sa.Column('user_id', sa.String(length=50), nullable=False),
|
||||
sa.Column('project_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('modification_instructions', sa.Text(), nullable=False, comment='综合修改指令'),
|
||||
sa.Column('original_suggestions', sa.JSON(), nullable=True, comment='来自分析的原始建议列表'),
|
||||
sa.Column('selected_suggestion_indices', sa.JSON(), nullable=True, comment='用户选择的建议索引'),
|
||||
sa.Column('custom_instructions', sa.Text(), nullable=True, comment='用户自定义修改意见'),
|
||||
sa.Column('style_id', sa.Integer(), nullable=True, comment='写作风格ID'),
|
||||
sa.Column('target_word_count', sa.Integer(), nullable=True, comment='目标字数'),
|
||||
sa.Column('focus_areas', sa.JSON(), nullable=True, comment='重点优化方向'),
|
||||
sa.Column('preserve_elements', sa.JSON(), nullable=True, comment='需要保留的元素配置'),
|
||||
sa.Column('status', sa.String(length=20), nullable=True, comment='pending/running/completed/failed'),
|
||||
sa.Column('progress', sa.Integer(), nullable=True, comment='进度 0-100'),
|
||||
sa.Column('error_message', sa.Text(), nullable=True),
|
||||
sa.Column('original_content', sa.Text(), nullable=True, comment='原始章节内容快照'),
|
||||
sa.Column('original_word_count', sa.Integer(), nullable=True, comment='原始字数'),
|
||||
sa.Column('regenerated_content', sa.Text(), nullable=True, comment='重新生成的内容'),
|
||||
sa.Column('regenerated_word_count', sa.Integer(), nullable=True, comment='新内容字数'),
|
||||
sa.Column('version_number', sa.Integer(), nullable=True, comment='版本号'),
|
||||
sa.Column('version_note', sa.String(length=500), nullable=True, comment='版本说明'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
|
||||
sa.Column('started_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('completed_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['chapter_id'], ['chapters.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
with op.batch_alter_table('regeneration_tasks', schema=None) as batch_op:
|
||||
batch_op.create_index(batch_op.f('ix_regeneration_tasks_chapter_id'), ['chapter_id'], unique=False)
|
||||
batch_op.create_index(batch_op.f('ix_regeneration_tasks_project_id'), ['project_id'], unique=False)
|
||||
batch_op.create_index(batch_op.f('ix_regeneration_tasks_user_id'), ['user_id'], unique=False)
|
||||
|
||||
op.create_table('story_memories',
|
||||
sa.Column('id', sa.String(length=100), nullable=False),
|
||||
sa.Column('project_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('chapter_id', sa.String(length=36), nullable=True),
|
||||
sa.Column('memory_type', sa.String(length=50), nullable=False, comment='\n 记忆类型:\n - plot_point: 情节点\n - character_event: 角色事件\n - world_detail: 世界观细节\n - hook: 钩子(悬念/冲突)\n - foreshadow: 伏笔\n - dialogue: 重要对话\n - scene: 场景描写\n '),
|
||||
sa.Column('title', sa.String(length=200), nullable=True, comment='记忆标题/简述'),
|
||||
sa.Column('content', sa.Text(), nullable=False, comment='记忆内容摘要(100-500字)'),
|
||||
sa.Column('full_context', sa.Text(), nullable=True, comment='完整上下文(可选,用于详细记录)'),
|
||||
sa.Column('related_characters', sa.JSON(), nullable=True, comment="涉及角色ID列表: ['char_id_1', 'char_id_2']"),
|
||||
sa.Column('related_locations', sa.JSON(), nullable=True, comment="涉及地点列表: ['地点1', '地点2']"),
|
||||
sa.Column('tags', sa.JSON(), nullable=True, comment="标签列表: ['悬念', '转折', '伏笔', '高潮']"),
|
||||
sa.Column('importance_score', sa.Float(), nullable=True, comment='重要性评分 0.0-1.0'),
|
||||
sa.Column('story_timeline', sa.Integer(), nullable=False, comment='故事时间线位置(章节序号)'),
|
||||
sa.Column('chapter_position', sa.Integer(), nullable=True, comment='章节内位置(字符位置)'),
|
||||
sa.Column('text_length', sa.Integer(), nullable=True, comment='文本长度(字符数)'),
|
||||
sa.Column('is_foreshadow', sa.Integer(), nullable=True, comment='伏笔状态: 0=普通记忆, 1=已埋下伏笔, 2=伏笔已回收'),
|
||||
sa.Column('foreshadow_resolved_at', sa.String(length=100), nullable=True, comment='伏笔回收的章节ID'),
|
||||
sa.Column('foreshadow_strength', sa.Float(), nullable=True, comment='伏笔强度 0.0-1.0'),
|
||||
sa.Column('vector_id', sa.String(length=100), nullable=True, comment='向量数据库中的唯一ID'),
|
||||
sa.Column('embedding_model', sa.String(length=100), nullable=True, comment='使用的embedding模型'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
|
||||
sa.ForeignKeyConstraint(['chapter_id'], ['chapters.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['foreshadow_resolved_at'], ['chapters.id'], ondelete='SET NULL'),
|
||||
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('vector_id')
|
||||
)
|
||||
with op.batch_alter_table('story_memories', schema=None) as batch_op:
|
||||
batch_op.create_index(batch_op.f('ix_story_memories_chapter_id'), ['chapter_id'], unique=False)
|
||||
batch_op.create_index(batch_op.f('ix_story_memories_memory_type'), ['memory_type'], unique=False)
|
||||
batch_op.create_index(batch_op.f('ix_story_memories_project_id'), ['project_id'], unique=False)
|
||||
batch_op.create_index(batch_op.f('ix_story_memories_story_timeline'), ['story_timeline'], unique=False)
|
||||
|
||||
op.create_table('organization_members',
|
||||
sa.Column('id', sa.String(length=36), nullable=False, comment='成员关系ID'),
|
||||
sa.Column('organization_id', sa.String(length=36), nullable=False, comment='组织ID'),
|
||||
sa.Column('character_id', sa.String(length=36), nullable=False, comment='角色ID'),
|
||||
sa.Column('position', sa.String(length=100), nullable=False, comment='职位名称'),
|
||||
sa.Column('rank', sa.Integer(), nullable=True, comment='职位等级'),
|
||||
sa.Column('status', sa.String(length=20), nullable=True, comment='状态:active/retired/expelled/deceased'),
|
||||
sa.Column('joined_at', sa.String(length=100), nullable=True, comment='加入时间(故事时间)'),
|
||||
sa.Column('left_at', sa.String(length=100), nullable=True, comment='离开时间(故事时间)'),
|
||||
sa.Column('loyalty', sa.Integer(), nullable=True, comment='忠诚度:0-100'),
|
||||
sa.Column('contribution', sa.Integer(), nullable=True, comment='贡献度:0-100'),
|
||||
sa.Column('source', sa.String(length=20), nullable=True, comment='来源:ai/manual'),
|
||||
sa.Column('notes', sa.Text(), nullable=True, comment='备注'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='更新时间'),
|
||||
sa.ForeignKeyConstraint(['character_id'], ['characters.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
with op.batch_alter_table('organization_members', schema=None) as batch_op:
|
||||
batch_op.create_index(batch_op.f('ix_organization_members_character_id'), ['character_id'], unique=False)
|
||||
batch_op.create_index(batch_op.f('ix_organization_members_organization_id'), ['organization_id'], unique=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('organization_members', schema=None) as batch_op:
|
||||
batch_op.drop_index(batch_op.f('ix_organization_members_organization_id'))
|
||||
batch_op.drop_index(batch_op.f('ix_organization_members_character_id'))
|
||||
|
||||
op.drop_table('organization_members')
|
||||
with op.batch_alter_table('story_memories', schema=None) as batch_op:
|
||||
batch_op.drop_index(batch_op.f('ix_story_memories_story_timeline'))
|
||||
batch_op.drop_index(batch_op.f('ix_story_memories_project_id'))
|
||||
batch_op.drop_index(batch_op.f('ix_story_memories_memory_type'))
|
||||
batch_op.drop_index(batch_op.f('ix_story_memories_chapter_id'))
|
||||
|
||||
op.drop_table('story_memories')
|
||||
with op.batch_alter_table('regeneration_tasks', schema=None) as batch_op:
|
||||
batch_op.drop_index(batch_op.f('ix_regeneration_tasks_user_id'))
|
||||
batch_op.drop_index(batch_op.f('ix_regeneration_tasks_project_id'))
|
||||
batch_op.drop_index(batch_op.f('ix_regeneration_tasks_chapter_id'))
|
||||
|
||||
op.drop_table('regeneration_tasks')
|
||||
with op.batch_alter_table('plot_analysis', schema=None) as batch_op:
|
||||
batch_op.drop_index(batch_op.f('ix_plot_analysis_project_id'))
|
||||
batch_op.drop_index(batch_op.f('ix_plot_analysis_chapter_id'))
|
||||
|
||||
op.drop_table('plot_analysis')
|
||||
with op.batch_alter_table('organizations', schema=None) as batch_op:
|
||||
batch_op.drop_index(batch_op.f('ix_organizations_project_id'))
|
||||
|
||||
op.drop_table('organizations')
|
||||
op.drop_table('generation_history')
|
||||
with op.batch_alter_table('character_relationships', schema=None) as batch_op:
|
||||
batch_op.drop_index(batch_op.f('ix_character_relationships_relationship_type_id'))
|
||||
batch_op.drop_index(batch_op.f('ix_character_relationships_project_id'))
|
||||
batch_op.drop_index(batch_op.f('ix_character_relationships_character_to_id'))
|
||||
batch_op.drop_index(batch_op.f('ix_character_relationships_character_from_id'))
|
||||
|
||||
op.drop_table('character_relationships')
|
||||
with op.batch_alter_table('character_careers', schema=None) as batch_op:
|
||||
batch_op.drop_index('idx_character_id')
|
||||
batch_op.drop_index('idx_character_career')
|
||||
batch_op.drop_index('idx_career_type')
|
||||
|
||||
op.drop_table('character_careers')
|
||||
with op.batch_alter_table('analysis_tasks', schema=None) as batch_op:
|
||||
batch_op.drop_index('idx_status')
|
||||
batch_op.drop_index('idx_chapter_id_created')
|
||||
|
||||
op.drop_table('analysis_tasks')
|
||||
op.drop_table('project_default_styles')
|
||||
op.drop_table('characters')
|
||||
op.drop_table('chapters')
|
||||
op.drop_table('writing_styles')
|
||||
op.drop_table('outlines')
|
||||
with op.batch_alter_table('careers', schema=None) as batch_op:
|
||||
batch_op.drop_index('idx_type')
|
||||
batch_op.drop_index('idx_project_id')
|
||||
|
||||
op.drop_table('careers')
|
||||
with op.batch_alter_table('users', schema=None) as batch_op:
|
||||
batch_op.drop_index(batch_op.f('ix_users_username'))
|
||||
batch_op.drop_index(batch_op.f('ix_users_user_id'))
|
||||
batch_op.drop_index(batch_op.f('ix_users_linuxdo_id'))
|
||||
|
||||
op.drop_table('users')
|
||||
with op.batch_alter_table('user_passwords', schema=None) as batch_op:
|
||||
batch_op.drop_index(batch_op.f('ix_user_passwords_user_id'))
|
||||
|
||||
op.drop_table('user_passwords')
|
||||
with op.batch_alter_table('settings', schema=None) as batch_op:
|
||||
batch_op.drop_index(batch_op.f('ix_settings_user_id'))
|
||||
batch_op.drop_index('idx_user_id')
|
||||
|
||||
op.drop_table('settings')
|
||||
with op.batch_alter_table('relationship_types', schema=None) as batch_op:
|
||||
batch_op.drop_index(batch_op.f('ix_relationship_types_id'))
|
||||
|
||||
op.drop_table('relationship_types')
|
||||
with op.batch_alter_table('prompt_templates', schema=None) as batch_op:
|
||||
batch_op.drop_index(batch_op.f('ix_prompt_templates_user_id'))
|
||||
batch_op.drop_index('idx_user_template')
|
||||
|
||||
op.drop_table('prompt_templates')
|
||||
with op.batch_alter_table('projects', schema=None) as batch_op:
|
||||
batch_op.drop_index(batch_op.f('ix_projects_user_id'))
|
||||
|
||||
op.drop_table('projects')
|
||||
with op.batch_alter_table('mcp_plugins', schema=None) as batch_op:
|
||||
batch_op.drop_index(batch_op.f('ix_mcp_plugins_user_id'))
|
||||
batch_op.drop_index('idx_user_plugin')
|
||||
batch_op.drop_index('idx_user_enabled')
|
||||
|
||||
op.drop_table('mcp_plugins')
|
||||
op.drop_table('batch_generation_tasks')
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,177 @@
|
||||
"""初始化SQLite预置数据
|
||||
|
||||
Revision ID: a1b2c3d4e5f6
|
||||
Revises: fbeb1038c728
|
||||
Create Date: 2025-12-27 08:56:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import table, column, String, Integer, Text
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'a1b2c3d4e5f6'
|
||||
down_revision: Union[str, None] = 'fbeb1038c728'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""插入预置数据"""
|
||||
|
||||
# ==================== 1. 插入关系类型数据 ====================
|
||||
relationship_types_table = table(
|
||||
'relationship_types',
|
||||
column('name', String),
|
||||
column('category', String),
|
||||
column('reverse_name', String),
|
||||
column('intimacy_range', String),
|
||||
column('icon', String),
|
||||
column('description', Text),
|
||||
)
|
||||
|
||||
relationship_types_data = [
|
||||
# 家庭关系
|
||||
{"name": "父亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👨", "description": "父子/父女关系"},
|
||||
{"name": "母亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👩", "description": "母子/母女关系"},
|
||||
{"name": "兄弟", "category": "family", "reverse_name": "兄弟", "intimacy_range": "high", "icon": "👬", "description": "兄弟关系"},
|
||||
{"name": "姐妹", "category": "family", "reverse_name": "姐妹", "intimacy_range": "high", "icon": "👭", "description": "姐妹关系"},
|
||||
{"name": "子女", "category": "family", "reverse_name": "父母", "intimacy_range": "high", "icon": "👶", "description": "子女关系"},
|
||||
{"name": "配偶", "category": "family", "reverse_name": "配偶", "intimacy_range": "high", "icon": "💑", "description": "夫妻关系"},
|
||||
{"name": "恋人", "category": "family", "reverse_name": "恋人", "intimacy_range": "high", "icon": "💕", "description": "恋爱关系"},
|
||||
|
||||
# 社交关系
|
||||
{"name": "师父", "category": "social", "reverse_name": "徒弟", "intimacy_range": "high", "icon": "🎓", "description": "师徒关系(师父视角)"},
|
||||
{"name": "徒弟", "category": "social", "reverse_name": "师父", "intimacy_range": "high", "icon": "📚", "description": "师徒关系(徒弟视角)"},
|
||||
{"name": "朋友", "category": "social", "reverse_name": "朋友", "intimacy_range": "medium", "icon": "🤝", "description": "朋友关系"},
|
||||
{"name": "同学", "category": "social", "reverse_name": "同学", "intimacy_range": "medium", "icon": "🎒", "description": "同学关系"},
|
||||
{"name": "邻居", "category": "social", "reverse_name": "邻居", "intimacy_range": "low", "icon": "🏘️", "description": "邻居关系"},
|
||||
{"name": "知己", "category": "social", "reverse_name": "知己", "intimacy_range": "high", "icon": "💙", "description": "知心好友"},
|
||||
|
||||
# 职业关系
|
||||
{"name": "上司", "category": "professional", "reverse_name": "下属", "intimacy_range": "low", "icon": "👔", "description": "上下级关系(上司视角)"},
|
||||
{"name": "下属", "category": "professional", "reverse_name": "上司", "intimacy_range": "low", "icon": "💼", "description": "上下级关系(下属视角)"},
|
||||
{"name": "同事", "category": "professional", "reverse_name": "同事", "intimacy_range": "medium", "icon": "🤵", "description": "同事关系"},
|
||||
{"name": "合作伙伴", "category": "professional", "reverse_name": "合作伙伴", "intimacy_range": "medium", "icon": "🤜🤛", "description": "合作关系"},
|
||||
|
||||
# 敌对关系
|
||||
{"name": "敌人", "category": "hostile", "reverse_name": "敌人", "intimacy_range": "low", "icon": "⚔️", "description": "敌对关系"},
|
||||
{"name": "仇人", "category": "hostile", "reverse_name": "仇人", "intimacy_range": "low", "icon": "💢", "description": "仇恨关系"},
|
||||
{"name": "竞争对手", "category": "hostile", "reverse_name": "竞争对手", "intimacy_range": "low", "icon": "🎯", "description": "竞争关系"},
|
||||
{"name": "宿敌", "category": "hostile", "reverse_name": "宿敌", "intimacy_range": "low", "icon": "⚡", "description": "宿命之敌"},
|
||||
]
|
||||
|
||||
op.bulk_insert(relationship_types_table, relationship_types_data)
|
||||
print(f"✅ SQLite: 已插入 {len(relationship_types_data)} 条关系类型数据")
|
||||
|
||||
|
||||
# ==================== 2. 插入全局写作风格预设 ====================
|
||||
writing_styles_table = table(
|
||||
'writing_styles',
|
||||
column('user_id', String),
|
||||
column('name', String),
|
||||
column('style_type', String),
|
||||
column('preset_id', String),
|
||||
column('description', Text),
|
||||
column('prompt_content', Text),
|
||||
column('order_index', Integer),
|
||||
)
|
||||
|
||||
writing_styles_data = [
|
||||
{
|
||||
"user_id": None, # NULL 表示全局预设
|
||||
"name": "自然流畅",
|
||||
"style_type": "preset",
|
||||
"preset_id": "natural",
|
||||
"description": "自然流畅的叙事风格,适合现代都市、现实题材",
|
||||
"prompt_content": """写作风格要求:
|
||||
1. 语言简洁明快,贴近现代口语
|
||||
2. 多用短句,节奏流畅
|
||||
3. 注重情感细节的自然流露
|
||||
4. 避免过度修饰和复杂句式""",
|
||||
"order_index": 1
|
||||
},
|
||||
{
|
||||
"user_id": None,
|
||||
"name": "古典优雅",
|
||||
"style_type": "preset",
|
||||
"preset_id": "classical",
|
||||
"description": "古典文雅的写作风格,适合古装、仙侠题材",
|
||||
"prompt_content": """写作风格要求:
|
||||
1. 使用文言、半文言或典雅的白话
|
||||
2. 适当运用古典诗词意象
|
||||
3. 注重意境营造和韵味
|
||||
4. 对话和描写保持古典美感""",
|
||||
"order_index": 2
|
||||
},
|
||||
{
|
||||
"user_id": None,
|
||||
"name": "现代简约",
|
||||
"style_type": "preset",
|
||||
"preset_id": "modern",
|
||||
"description": "现代简约风格,适合轻小说、网文快节奏叙事",
|
||||
"prompt_content": """写作风格要求:
|
||||
1. 语言直白简练,信息密度高
|
||||
2. 多用对话推进情节
|
||||
3. 避免冗长描写,突出关键动作
|
||||
4. 节奏明快,适合快速阅读""",
|
||||
"order_index": 3
|
||||
},
|
||||
{
|
||||
"user_id": None,
|
||||
"name": "文艺细腻",
|
||||
"style_type": "preset",
|
||||
"preset_id": "literary",
|
||||
"description": "文艺细腻风格,注重心理描写和氛围营造",
|
||||
"prompt_content": """写作风格要求:
|
||||
1. 注重心理活动和情感细节
|
||||
2. 善用环境描写烘托氛围
|
||||
3. 语言优美,富有文学性
|
||||
4. 适当使用比喻、象征等修辞手法""",
|
||||
"order_index": 4
|
||||
},
|
||||
{
|
||||
"user_id": None,
|
||||
"name": "紧张悬疑",
|
||||
"style_type": "preset",
|
||||
"preset_id": "suspense",
|
||||
"description": "紧张悬疑风格,适合推理、惊悚题材",
|
||||
"prompt_content": """写作风格要求:
|
||||
1. 营造紧张压迫的氛围
|
||||
2. 多用短句加快节奏
|
||||
3. 善于设置悬念和伏笔
|
||||
4. 注重细节描写,为推理埋下线索""",
|
||||
"order_index": 5
|
||||
},
|
||||
{
|
||||
"user_id": None,
|
||||
"name": "幽默诙谐",
|
||||
"style_type": "preset",
|
||||
"preset_id": "humorous",
|
||||
"description": "幽默诙谐风格,适合轻松搞笑题材",
|
||||
"prompt_content": """写作风格要求:
|
||||
1. 语言活泼风趣,善用俏皮话
|
||||
2. 注重对话的喜剧效果
|
||||
3. 适当夸张和反转制造笑点
|
||||
4. 保持轻松愉快的基调""",
|
||||
"order_index": 6
|
||||
},
|
||||
]
|
||||
|
||||
op.bulk_insert(writing_styles_table, writing_styles_data)
|
||||
print(f"✅ SQLite: 已插入 {len(writing_styles_data)} 条全局写作风格预设")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""删除预置数据"""
|
||||
|
||||
# 删除写作风格预设(只删除全局预设)
|
||||
op.execute("DELETE FROM writing_styles WHERE user_id IS NULL")
|
||||
print("✅ SQLite: 已删除全局写作风格预设")
|
||||
|
||||
# 删除关系类型
|
||||
op.execute("DELETE FROM relationship_types")
|
||||
print("✅ SQLite: 已删除关系类型数据")
|
||||
+34
@@ -0,0 +1,34 @@
|
||||
"""添加system_prompt字段到settings表
|
||||
|
||||
Revision ID: 7899f8d4d839
|
||||
Revises: a1b2c3d4e5f6
|
||||
Create Date: 2025-12-27 17:00:35.440551
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '7899f8d4d839'
|
||||
down_revision: Union[str, None] = 'a1b2c3d4e5f6'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('settings', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('system_prompt', sa.Text(), nullable=True, comment='系统级别提示词,每次AI调用都会使用'))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('settings', schema=None) as batch_op:
|
||||
batch_op.drop_column('system_prompt')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,388 @@
|
||||
"""
|
||||
管理员API - 用户管理功能
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, Request, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
import hashlib
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.user_manager import user_manager
|
||||
from app.user_password import password_manager
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/admin", tags=["管理员"])
|
||||
|
||||
|
||||
# ==================== 请求/响应模型 ====================
|
||||
|
||||
class CreateUserRequest(BaseModel):
|
||||
"""创建用户请求"""
|
||||
username: str = Field(..., min_length=3, max_length=20, description="用户名")
|
||||
display_name: str = Field(..., min_length=2, max_length=50, description="显示名称")
|
||||
password: Optional[str] = Field(None, min_length=6, description="初始密码,留空则自动生成")
|
||||
avatar_url: Optional[str] = Field(None, description="头像URL")
|
||||
trust_level: int = Field(0, ge=0, le=9, description="信任等级")
|
||||
is_admin: bool = Field(False, description="是否为管理员")
|
||||
|
||||
|
||||
class UpdateUserRequest(BaseModel):
|
||||
"""更新用户请求"""
|
||||
display_name: Optional[str] = Field(None, min_length=2, max_length=50)
|
||||
avatar_url: Optional[str] = None
|
||||
trust_level: Optional[int] = Field(None, ge=-1, le=9)
|
||||
is_admin: Optional[bool] = Field(None, description="是否为管理员")
|
||||
|
||||
|
||||
class ToggleStatusRequest(BaseModel):
|
||||
"""切换用户状态请求"""
|
||||
is_active: bool = Field(..., description="true=启用, false=禁用")
|
||||
|
||||
|
||||
class ResetPasswordRequest(BaseModel):
|
||||
"""重置密码请求"""
|
||||
new_password: Optional[str] = Field(None, min_length=6, description="新密码,留空则重置为默认密码")
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
"""用户响应"""
|
||||
user_id: str
|
||||
username: str
|
||||
display_name: str
|
||||
avatar_url: Optional[str]
|
||||
trust_level: int
|
||||
is_admin: bool
|
||||
is_active: bool
|
||||
linuxdo_id: str
|
||||
created_at: str
|
||||
last_login: Optional[str]
|
||||
|
||||
|
||||
class CreateUserResponse(BaseModel):
|
||||
"""创建用户响应"""
|
||||
success: bool
|
||||
message: str
|
||||
user: dict
|
||||
default_password: Optional[str] = None
|
||||
|
||||
|
||||
# ==================== 权限检查依赖 ====================
|
||||
|
||||
async def check_admin(request: Request) -> User:
|
||||
"""检查管理员权限"""
|
||||
user = getattr(request.state, "user", None)
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
if not user.is_admin:
|
||||
raise HTTPException(status_code=403, detail="需要管理员权限")
|
||||
|
||||
return user
|
||||
|
||||
|
||||
# ==================== API 端点 ====================
|
||||
|
||||
@router.get("/users")
|
||||
async def get_users(
|
||||
admin: User = Depends(check_admin),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取用户列表(仅管理员)"""
|
||||
try:
|
||||
all_users = await user_manager.get_all_users()
|
||||
|
||||
users_data = []
|
||||
for user in all_users:
|
||||
# user_manager 返回的是 Pydantic User 对象,直接转为 dict
|
||||
user_dict = user.model_dump()
|
||||
user_dict["is_active"] = user.trust_level != -1
|
||||
users_data.append(user_dict)
|
||||
|
||||
logger.info(f"管理员 {admin.user_id} 获取用户列表,共 {len(users_data)} 个用户")
|
||||
|
||||
return {
|
||||
"total": len(users_data),
|
||||
"users": users_data
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户列表失败: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"获取用户列表失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/users")
|
||||
async def create_user(
|
||||
data: CreateUserRequest,
|
||||
admin: User = Depends(check_admin),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""添加用户(仅管理员)"""
|
||||
try:
|
||||
# 检查用户名是否已存在
|
||||
all_users = await user_manager.get_all_users()
|
||||
for user in all_users:
|
||||
if user.username == data.username:
|
||||
raise HTTPException(status_code=409, detail="用户名已存在")
|
||||
|
||||
# 生成用户ID
|
||||
user_id = f"admin_created_{hashlib.md5(data.username.encode()).hexdigest()[:16]}"
|
||||
|
||||
# 创建用户
|
||||
new_user = await user_manager.create_or_update_from_linuxdo(
|
||||
linuxdo_id=user_id,
|
||||
username=data.username,
|
||||
display_name=data.display_name,
|
||||
avatar_url=data.avatar_url,
|
||||
trust_level=data.trust_level
|
||||
)
|
||||
|
||||
# 设置管理员标志
|
||||
if data.is_admin:
|
||||
# 直接更新数据库中的is_admin字段
|
||||
async with await user_manager._get_session() as session:
|
||||
result = await session.execute(
|
||||
select(User).where(User.user_id == user_id)
|
||||
)
|
||||
db_user = result.scalar_one_or_none()
|
||||
if db_user:
|
||||
db_user.is_admin = True
|
||||
await session.commit()
|
||||
new_user.is_admin = True
|
||||
|
||||
# 设置密码
|
||||
actual_password = await password_manager.set_password(
|
||||
user_id=new_user.user_id,
|
||||
username=data.username,
|
||||
password=data.password
|
||||
)
|
||||
|
||||
# Settings 将在首次访问设置页面时自动创建(延迟初始化)
|
||||
|
||||
logger.info(f"管理员 {admin.user_id} 创建了新用户 {new_user.user_id} ({data.username})")
|
||||
|
||||
return CreateUserResponse(
|
||||
success=True,
|
||||
message="用户创建成功",
|
||||
user=new_user.model_dump(),
|
||||
default_password=actual_password if not data.password else None
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"创建用户失败: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"创建用户失败: {str(e)}")
|
||||
|
||||
|
||||
@router.put("/users/{user_id}")
|
||||
async def update_user(
|
||||
user_id: str,
|
||||
data: UpdateUserRequest,
|
||||
admin: User = Depends(check_admin),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""编辑用户信息(仅管理员)"""
|
||||
try:
|
||||
# 获取目标用户
|
||||
target_user = await user_manager.get_user(user_id)
|
||||
if not target_user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
# 更新用户信息
|
||||
async with await user_manager._get_session() as session:
|
||||
result = await session.execute(
|
||||
select(User).where(User.user_id == user_id)
|
||||
)
|
||||
db_user = result.scalar_one_or_none()
|
||||
|
||||
if not db_user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
# 更新字段
|
||||
if data.display_name is not None:
|
||||
db_user.display_name = data.display_name
|
||||
if data.avatar_url is not None:
|
||||
db_user.avatar_url = data.avatar_url
|
||||
if data.trust_level is not None:
|
||||
db_user.trust_level = data.trust_level
|
||||
if data.is_admin is not None:
|
||||
# 检查是否是最后一个管理员
|
||||
if db_user.is_admin and not data.is_admin:
|
||||
all_users = await user_manager.get_all_users()
|
||||
admin_count = sum(1 for u in all_users if u.is_admin)
|
||||
if admin_count <= 1:
|
||||
raise HTTPException(status_code=400, detail="不能取消最后一个管理员的权限")
|
||||
db_user.is_admin = data.is_admin
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(db_user)
|
||||
|
||||
logger.info(f"管理员 {admin.user_id} 更新了用户 {user_id} 的信息")
|
||||
|
||||
updated_user = await user_manager.get_user(user_id)
|
||||
user_dict = updated_user.model_dump()
|
||||
user_dict["is_active"] = updated_user.trust_level != -1
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "用户信息更新成功",
|
||||
"user": user_dict
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新用户失败: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"更新用户失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/users/{user_id}/toggle-status")
|
||||
async def toggle_user_status(
|
||||
user_id: str,
|
||||
data: ToggleStatusRequest,
|
||||
admin: User = Depends(check_admin),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""切换用户状态(启用/禁用)(仅管理员)"""
|
||||
try:
|
||||
# 不允许禁用自己
|
||||
if user_id == admin.user_id:
|
||||
raise HTTPException(status_code=400, detail="不能禁用自己的账号")
|
||||
|
||||
# 获取目标用户
|
||||
target_user = await user_manager.get_user(user_id)
|
||||
if not target_user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
# 更新状态
|
||||
async with await user_manager._get_session() as session:
|
||||
result = await session.execute(
|
||||
select(User).where(User.user_id == user_id)
|
||||
)
|
||||
db_user = result.scalar_one_or_none()
|
||||
|
||||
if not db_user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
if data.is_active:
|
||||
# 启用用户:恢复trust_level为0(或之前的值)
|
||||
db_user.trust_level = 0
|
||||
else:
|
||||
# 禁用用户:设置trust_level为-1
|
||||
db_user.trust_level = -1
|
||||
|
||||
await session.commit()
|
||||
|
||||
status_text = "启用" if data.is_active else "禁用"
|
||||
logger.info(f"管理员 {admin.user_id} {status_text}了用户 {user_id}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"用户已{status_text}",
|
||||
"is_active": data.is_active
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"切换用户状态失败: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"切换用户状态失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/users/{user_id}/reset-password")
|
||||
async def reset_password(
|
||||
user_id: str,
|
||||
data: ResetPasswordRequest,
|
||||
admin: User = Depends(check_admin),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""重置用户密码(仅管理员)"""
|
||||
try:
|
||||
# 获取目标用户
|
||||
target_user = await user_manager.get_user(user_id)
|
||||
if not target_user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
# 重置密码
|
||||
actual_password = await password_manager.set_password(
|
||||
user_id=user_id,
|
||||
username=target_user.username,
|
||||
password=data.new_password
|
||||
)
|
||||
|
||||
logger.info(f"管理员 {admin.user_id} 重置了用户 {user_id} 的密码")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "密码重置成功",
|
||||
"new_password": actual_password
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"重置密码失败: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"重置密码失败: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/users/{user_id}")
|
||||
async def delete_user(
|
||||
user_id: str,
|
||||
admin: User = Depends(check_admin),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""删除用户(仅管理员,慎用)"""
|
||||
try:
|
||||
# 不允许删除自己
|
||||
if user_id == admin.user_id:
|
||||
raise HTTPException(status_code=400, detail="不能删除自己的账号")
|
||||
|
||||
# 获取目标用户
|
||||
target_user = await user_manager.get_user(user_id)
|
||||
if not target_user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
# 检查是否是最后一个管理员
|
||||
if target_user.is_admin:
|
||||
all_users = await user_manager.get_all_users()
|
||||
admin_count = sum(1 for u in all_users if u.is_admin)
|
||||
if admin_count <= 1:
|
||||
raise HTTPException(status_code=400, detail="不能删除最后一个管理员账号")
|
||||
|
||||
# 删除用户(包括密码记录)
|
||||
async with await user_manager._get_session() as session:
|
||||
# 删除用户记录
|
||||
result = await session.execute(
|
||||
select(User).where(User.user_id == user_id)
|
||||
)
|
||||
db_user = result.scalar_one_or_none()
|
||||
if db_user:
|
||||
await session.delete(db_user)
|
||||
|
||||
# 删除密码记录
|
||||
from app.models.user import UserPassword
|
||||
result = await session.execute(
|
||||
select(UserPassword).where(UserPassword.user_id == user_id)
|
||||
)
|
||||
pwd_record = result.scalar_one_or_none()
|
||||
if pwd_record:
|
||||
await session.delete(pwd_record)
|
||||
|
||||
await session.commit()
|
||||
|
||||
logger.warning(f"管理员 {admin.user_id} 删除了用户 {user_id}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "用户已删除"
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"删除用户失败: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"删除用户失败: {str(e)}")
|
||||
+250
-32
@@ -9,7 +9,7 @@ import hashlib
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from app.services.oauth_service import LinuxDOOAuthService
|
||||
from app.user_manager import user_manager
|
||||
from app.database import init_db
|
||||
from app.user_password import password_manager
|
||||
from app.logger import get_logger
|
||||
from app.config import settings
|
||||
|
||||
@@ -49,6 +49,25 @@ class LocalLoginResponse(BaseModel):
|
||||
user: Optional[dict] = None
|
||||
|
||||
|
||||
class SetPasswordRequest(BaseModel):
|
||||
"""设置密码请求"""
|
||||
password: str
|
||||
|
||||
|
||||
class SetPasswordResponse(BaseModel):
|
||||
"""设置密码响应"""
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class PasswordStatusResponse(BaseModel):
|
||||
"""密码状态响应"""
|
||||
has_password: bool
|
||||
has_custom_password: bool
|
||||
username: Optional[str] = None
|
||||
default_password: Optional[str] = None
|
||||
|
||||
|
||||
@router.get("/config")
|
||||
async def get_auth_config():
|
||||
"""获取认证配置信息"""
|
||||
@@ -60,37 +79,79 @@ async def get_auth_config():
|
||||
|
||||
@router.post("/local/login", response_model=LocalLoginResponse)
|
||||
async def local_login(request: LocalLoginRequest, response: Response):
|
||||
"""本地账户登录"""
|
||||
"""本地账户登录(支持.env配置的管理员账号和Linux DO授权后绑定的账号)"""
|
||||
# 检查是否启用本地登录
|
||||
if not settings.LOCAL_AUTH_ENABLED:
|
||||
raise HTTPException(status_code=403, detail="本地账户登录未启用")
|
||||
|
||||
# 检查是否配置了本地账户
|
||||
if not settings.LOCAL_AUTH_USERNAME or not settings.LOCAL_AUTH_PASSWORD:
|
||||
raise HTTPException(status_code=500, detail="本地账户未配置")
|
||||
logger.info(f"[本地登录] 尝试登录用户名: {request.username}")
|
||||
|
||||
# 验证用户名和密码
|
||||
if request.username != settings.LOCAL_AUTH_USERNAME or request.password != settings.LOCAL_AUTH_PASSWORD:
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
# 首先尝试查找 Linux DO 授权后绑定的账号
|
||||
all_users = await user_manager.get_all_users()
|
||||
target_user = None
|
||||
|
||||
# 生成本地用户ID(使用用户名的hash)
|
||||
user_id = f"local_{hashlib.md5(request.username.encode()).hexdigest()[:16]}"
|
||||
for user in all_users:
|
||||
# 同时检查 users 表的 username 和 user_passwords 表的 username
|
||||
password_username = await password_manager.get_username(user.user_id)
|
||||
if user.username == request.username or password_username == request.username:
|
||||
target_user = user
|
||||
logger.info(f"[本地登录] 找到 Linux DO 授权用户: {user.user_id}")
|
||||
break
|
||||
|
||||
# 创建或更新本地用户
|
||||
user = await user_manager.create_or_update_from_linuxdo(
|
||||
linuxdo_id=user_id,
|
||||
username=request.username,
|
||||
display_name=settings.LOCAL_AUTH_DISPLAY_NAME,
|
||||
avatar_url=None,
|
||||
trust_level=9 # 本地用户给予高信任级别
|
||||
)
|
||||
# 如果找到了 Linux DO 授权的用户
|
||||
if target_user:
|
||||
# 检查是否有密码
|
||||
if not await password_manager.has_password(target_user.user_id):
|
||||
logger.warning(f"[本地登录] 用户 {target_user.user_id} 没有设置密码")
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
# 验证密码
|
||||
if not await password_manager.verify_password(target_user.user_id, request.password):
|
||||
logger.warning(f"[本地登录] 用户 {target_user.user_id} 密码验证失败")
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
logger.info(f"[本地登录] Linux DO 授权用户 {target_user.user_id} 登录成功")
|
||||
user = target_user
|
||||
else:
|
||||
# 没有找到 Linux DO 用户,尝试 .env 配置的管理员账号
|
||||
logger.info(f"[本地登录] 未找到 Linux DO 用户,检查 .env 管理员账号")
|
||||
|
||||
# 检查是否配置了本地账户
|
||||
if not settings.LOCAL_AUTH_USERNAME or not settings.LOCAL_AUTH_PASSWORD:
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
# 生成本地用户ID(使用用户名的hash)
|
||||
user_id = f"local_{hashlib.md5(request.username.encode()).hexdigest()[:16]}"
|
||||
|
||||
# 检查用户是否存在
|
||||
user = await user_manager.get_user(user_id)
|
||||
|
||||
# 如果用户不存在,使用.env中的默认密码验证
|
||||
if not user:
|
||||
# 验证用户名和密码(使用.env配置)
|
||||
if request.username != settings.LOCAL_AUTH_USERNAME or request.password != settings.LOCAL_AUTH_PASSWORD:
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
# 创建本地用户
|
||||
user = await user_manager.create_or_update_from_linuxdo(
|
||||
linuxdo_id=user_id,
|
||||
username=request.username,
|
||||
display_name=settings.LOCAL_AUTH_DISPLAY_NAME,
|
||||
avatar_url=None,
|
||||
trust_level=9 # 本地用户给予高信任级别
|
||||
)
|
||||
|
||||
# 为新用户设置默认密码到数据库
|
||||
await password_manager.set_password(user.user_id, request.username, request.password)
|
||||
logger.info(f"[本地登录] 管理员用户 {user.user_id} 初始密码已设置到数据库")
|
||||
else:
|
||||
# 用户已存在,使用数据库中的密码验证
|
||||
if not await password_manager.verify_password(user.user_id, request.password):
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
logger.info(f"[本地登录] 管理员用户 {user.user_id} 登录成功")
|
||||
|
||||
# 初始化用户数据库
|
||||
try:
|
||||
await init_db(user.user_id)
|
||||
logger.info(f"本地用户 {user.user_id} 数据库初始化成功")
|
||||
except Exception as e:
|
||||
logger.error(f"本地用户 {user.user_id} 数据库初始化失败: {e}")
|
||||
# Settings 将在首次访问设置页面时自动创建(延迟初始化)
|
||||
|
||||
# 设置 Cookie(2小时有效)
|
||||
max_age = settings.SESSION_EXPIRE_MINUTES * 60
|
||||
@@ -189,13 +250,12 @@ async def _handle_callback(
|
||||
trust_level=trust_level
|
||||
)
|
||||
|
||||
# 3.5. 初始化用户数据库(如果是新用户)
|
||||
try:
|
||||
await init_db(user.user_id)
|
||||
logger.info(f"用户 {user.user_id} 数据库初始化成功")
|
||||
except Exception as e:
|
||||
logger.error(f"用户 {user.user_id} 数据库初始化失败: {e}")
|
||||
# 继续执行,不影响登录流程(可能是已存在的用户)
|
||||
# 3.1. 检查是否是首次登录(没有密码记录)
|
||||
is_first_login = not await password_manager.has_password(user.user_id)
|
||||
if is_first_login:
|
||||
logger.info(f"用户 {user.user_id} ({username}) 首次登录,需要初始化密码")
|
||||
|
||||
# Settings 将在首次访问设置页面时自动创建(延迟初始化)
|
||||
|
||||
# 4. 设置 Cookie 并重定向到前端回调页面
|
||||
# 使用配置的前端URL,支持不同的部署环境
|
||||
@@ -229,6 +289,17 @@ async def _handle_callback(
|
||||
samesite="lax"
|
||||
)
|
||||
|
||||
# 如果是首次登录,设置标记 Cookie(5分钟有效,仅用于前端显示初始密码提示)
|
||||
if is_first_login:
|
||||
redirect_response.set_cookie(
|
||||
key="first_login",
|
||||
value="true",
|
||||
max_age=300, # 5分钟有效
|
||||
httponly=False, # 前端需要读取
|
||||
samesite="lax"
|
||||
)
|
||||
logger.info(f"✅ [OAuth登录] 用户 {user.user_id} 首次登录,已设置 first_login 标记")
|
||||
|
||||
return redirect_response
|
||||
|
||||
|
||||
@@ -337,4 +408,151 @@ async def get_current_user(request: Request):
|
||||
if not hasattr(request.state, "user") or not request.state.user:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
return request.state.user.dict()
|
||||
return request.state.user.dict()
|
||||
|
||||
|
||||
@router.get("/password/status", response_model=PasswordStatusResponse)
|
||||
async def get_password_status(request: Request):
|
||||
"""获取当前用户的密码状态"""
|
||||
if not hasattr(request.state, "user") or not request.state.user:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
user = request.state.user
|
||||
has_password = await password_manager.has_password(user.user_id)
|
||||
has_custom = await password_manager.has_custom_password(user.user_id)
|
||||
username = await password_manager.get_username(user.user_id)
|
||||
|
||||
# 如果使用默认密码,返回默认密码供用户查看
|
||||
default_password = None
|
||||
if has_password and not has_custom:
|
||||
default_password = f"{user.username}@666"
|
||||
|
||||
return PasswordStatusResponse(
|
||||
has_password=has_password,
|
||||
has_custom_password=has_custom,
|
||||
username=username or user.username,
|
||||
default_password=default_password
|
||||
)
|
||||
|
||||
|
||||
@router.post("/password/set", response_model=SetPasswordResponse)
|
||||
async def set_user_password(request: Request, password_req: SetPasswordRequest):
|
||||
"""设置当前用户的密码"""
|
||||
if not hasattr(request.state, "user") or not request.state.user:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
user = request.state.user
|
||||
|
||||
# 验证密码强度(至少6个字符)
|
||||
if len(password_req.password) < 6:
|
||||
raise HTTPException(status_code=400, detail="密码长度至少为6个字符")
|
||||
|
||||
# 设置密码
|
||||
await password_manager.set_password(user.user_id, user.username, password_req.password)
|
||||
logger.info(f"用户 {user.user_id} ({user.username}) 设置了自定义密码")
|
||||
|
||||
return SetPasswordResponse(
|
||||
success=True,
|
||||
message="密码设置成功"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/password/initialize", response_model=SetPasswordResponse)
|
||||
async def initialize_user_password(request: Request, password_req: SetPasswordRequest):
|
||||
"""
|
||||
初始化首次登录用户的密码
|
||||
|
||||
用于首次通过 Linux DO 授权登录的用户,可以选择设置自定义密码或使用默认密码
|
||||
"""
|
||||
if not hasattr(request.state, "user") or not request.state.user:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
user = request.state.user
|
||||
|
||||
# 检查是否已经有密码(防止重复初始化)
|
||||
if await password_manager.has_password(user.user_id):
|
||||
raise HTTPException(status_code=400, detail="密码已经初始化,请使用密码修改功能")
|
||||
|
||||
# 验证密码强度(至少6个字符)
|
||||
if len(password_req.password) < 6:
|
||||
raise HTTPException(status_code=400, detail="密码长度至少为6个字符")
|
||||
|
||||
# 设置密码
|
||||
await password_manager.set_password(user.user_id, user.username, password_req.password)
|
||||
logger.info(f"用户 {user.user_id} ({user.username}) 初始化密码成功")
|
||||
|
||||
return SetPasswordResponse(
|
||||
success=True,
|
||||
message="密码初始化成功"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/bind/login", response_model=LocalLoginResponse)
|
||||
async def bind_account_login(request: LocalLoginRequest, response: Response):
|
||||
"""使用绑定的账号密码登录(LinuxDO授权后绑定的账号)"""
|
||||
# 查找用户
|
||||
all_users = await user_manager.get_all_users()
|
||||
target_user = None
|
||||
|
||||
logger.info(f"[绑定账号登录] 尝试登录用户名: {request.username}")
|
||||
logger.info(f"[绑定账号登录] 当前共有 {len(all_users)} 个用户")
|
||||
|
||||
for user in all_users:
|
||||
# 同时检查 users 表的 username 和 user_passwords 表的 username
|
||||
password_username = await password_manager.get_username(user.user_id)
|
||||
logger.info(f"[绑定账号登录] 检查用户 {user.user_id}: users.username={user.username}, passwords.username={password_username}")
|
||||
|
||||
if user.username == request.username or password_username == request.username:
|
||||
target_user = user
|
||||
logger.info(f"[绑定账号登录] 找到匹配用户: {user.user_id}")
|
||||
break
|
||||
|
||||
if not target_user:
|
||||
logger.warning(f"[绑定账号登录] 用户名 {request.username} 未找到")
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
# 检查是否有密码记录
|
||||
has_pwd = await password_manager.has_password(target_user.user_id)
|
||||
if not has_pwd:
|
||||
logger.warning(f"[绑定账号登录] 用户 {target_user.user_id} 没有设置密码")
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
# 验证密码
|
||||
is_valid = await password_manager.verify_password(target_user.user_id, request.password)
|
||||
logger.info(f"[绑定账号登录] 用户 {target_user.user_id} 密码验证结果: {is_valid}")
|
||||
|
||||
if not is_valid:
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
# Settings 将在首次访问设置页面时自动创建(延迟初始化)
|
||||
|
||||
# 设置 Cookie(2小时有效)
|
||||
max_age = settings.SESSION_EXPIRE_MINUTES * 60
|
||||
response.set_cookie(
|
||||
key="user_id",
|
||||
value=target_user.user_id,
|
||||
max_age=max_age,
|
||||
httponly=True,
|
||||
samesite="lax"
|
||||
)
|
||||
|
||||
# 设置过期时间戳 Cookie(用于前端判断)
|
||||
china_now = get_china_now()
|
||||
expire_time = china_now + timedelta(minutes=settings.SESSION_EXPIRE_MINUTES)
|
||||
expire_at = int(expire_time.timestamp())
|
||||
|
||||
logger.info(f"✅ [绑定账号登录] 用户 {target_user.user_id} ({request.username}) 登录成功,会话有效期 {settings.SESSION_EXPIRE_MINUTES} 分钟")
|
||||
|
||||
response.set_cookie(
|
||||
key="session_expire_at",
|
||||
value=str(expire_at),
|
||||
max_age=max_age,
|
||||
httponly=False, # 前端需要读取
|
||||
samesite="lax"
|
||||
)
|
||||
|
||||
return LocalLoginResponse(
|
||||
success=True,
|
||||
message="登录成功",
|
||||
user=target_user.dict()
|
||||
)
|
||||
@@ -0,0 +1,938 @@
|
||||
|
||||
"""职业管理API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, and_
|
||||
import json
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from app.database import get_db
|
||||
from app.utils.sse_response import SSEResponse, create_sse_response
|
||||
from app.models.career import Career, CharacterCareer
|
||||
from app.models.character import Character
|
||||
from app.models.project import Project
|
||||
from app.schemas.career import (
|
||||
CareerCreate,
|
||||
CareerUpdate,
|
||||
CareerResponse,
|
||||
CareerListResponse,
|
||||
CareerGenerateRequest,
|
||||
CharacterCareerResponse,
|
||||
CharacterCareerDetail,
|
||||
SetMainCareerRequest,
|
||||
AddSubCareerRequest,
|
||||
UpdateCareerStageRequest,
|
||||
CareerStage
|
||||
)
|
||||
from app.services.ai_service import AIService
|
||||
from app.logger import get_logger
|
||||
from app.api.settings import get_user_ai_service
|
||||
|
||||
router = APIRouter(prefix="/careers", tags=["职业管理"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
|
||||
"""验证用户是否有权访问指定项目"""
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
result = await db.execute(
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||
|
||||
return project
|
||||
|
||||
|
||||
@router.get("", response_model=CareerListResponse, summary="获取职业列表")
|
||||
async def get_careers(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取指定项目的所有职业"""
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
# 获取总数
|
||||
count_result = await db.execute(
|
||||
select(func.count(Career.id)).where(Career.project_id == project_id)
|
||||
)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# 获取职业列表
|
||||
result = await db.execute(
|
||||
select(Career)
|
||||
.where(Career.project_id == project_id)
|
||||
.order_by(Career.type, Career.created_at.desc())
|
||||
)
|
||||
careers = result.scalars().all()
|
||||
|
||||
# 分类返回
|
||||
main_careers = []
|
||||
sub_careers = []
|
||||
|
||||
for career in careers:
|
||||
# 解析JSON字段
|
||||
stages = json.loads(career.stages) if career.stages else []
|
||||
attribute_bonuses = json.loads(career.attribute_bonuses) if career.attribute_bonuses else None
|
||||
|
||||
career_dict = {
|
||||
"id": career.id,
|
||||
"project_id": career.project_id,
|
||||
"name": career.name,
|
||||
"type": career.type,
|
||||
"description": career.description,
|
||||
"category": career.category,
|
||||
"stages": stages,
|
||||
"max_stage": career.max_stage,
|
||||
"requirements": career.requirements,
|
||||
"special_abilities": career.special_abilities,
|
||||
"worldview_rules": career.worldview_rules,
|
||||
"attribute_bonuses": attribute_bonuses,
|
||||
"source": career.source,
|
||||
"created_at": career.created_at,
|
||||
"updated_at": career.updated_at
|
||||
}
|
||||
|
||||
if career.type == "main":
|
||||
main_careers.append(career_dict)
|
||||
else:
|
||||
sub_careers.append(career_dict)
|
||||
|
||||
return CareerListResponse(
|
||||
total=total,
|
||||
main_careers=main_careers,
|
||||
sub_careers=sub_careers
|
||||
)
|
||||
|
||||
|
||||
@router.post("", response_model=CareerResponse, summary="创建职业")
|
||||
async def create_career(
|
||||
career_data: CareerCreate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""手动创建职业"""
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(career_data.project_id, user_id, db)
|
||||
|
||||
try:
|
||||
# 转换stages为JSON字符串
|
||||
stages_json = json.dumps([stage.model_dump() for stage in career_data.stages], ensure_ascii=False)
|
||||
attribute_bonuses_json = json.dumps(career_data.attribute_bonuses, ensure_ascii=False) if career_data.attribute_bonuses else None
|
||||
|
||||
# 创建职业
|
||||
career = Career(
|
||||
project_id=career_data.project_id,
|
||||
name=career_data.name,
|
||||
type=career_data.type,
|
||||
description=career_data.description,
|
||||
category=career_data.category,
|
||||
stages=stages_json,
|
||||
max_stage=career_data.max_stage,
|
||||
requirements=career_data.requirements,
|
||||
special_abilities=career_data.special_abilities,
|
||||
worldview_rules=career_data.worldview_rules,
|
||||
attribute_bonuses=attribute_bonuses_json,
|
||||
source=career_data.source
|
||||
)
|
||||
db.add(career)
|
||||
await db.commit()
|
||||
await db.refresh(career)
|
||||
|
||||
logger.info(f"✅ 创建职业成功:{career.name} (ID: {career.id}, 类型: {career.type})")
|
||||
|
||||
return CareerResponse(
|
||||
id=career.id,
|
||||
project_id=career.project_id,
|
||||
name=career.name,
|
||||
type=career.type,
|
||||
description=career.description,
|
||||
category=career.category,
|
||||
stages=career_data.stages,
|
||||
max_stage=career.max_stage,
|
||||
requirements=career.requirements,
|
||||
special_abilities=career.special_abilities,
|
||||
worldview_rules=career.worldview_rules,
|
||||
attribute_bonuses=career_data.attribute_bonuses,
|
||||
source=career.source,
|
||||
created_at=career.created_at,
|
||||
updated_at=career.updated_at
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建职业失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"创建职业失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/generate-system", summary="AI生成新职业(增量式,流式)")
|
||||
async def generate_career_system(
|
||||
project_id: str,
|
||||
main_career_count: int = 3,
|
||||
sub_career_count: int = 6,
|
||||
enable_mcp: bool = False,
|
||||
http_request: Request = None,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
"""
|
||||
使用AI生成新职业(增量式,基于已有职业补充,支持SSE流式进度显示)
|
||||
|
||||
通过Server-Sent Events返回实时进度信息
|
||||
"""
|
||||
async def generate() -> AsyncGenerator[str, None]:
|
||||
try:
|
||||
# 验证用户权限和项目是否存在
|
||||
user_id = getattr(http_request.state, 'user_id', None)
|
||||
project = await verify_project_access(project_id, user_id, db)
|
||||
|
||||
yield await SSEResponse.send_progress("开始生成新职业...", 0)
|
||||
|
||||
# 获取已有职业列表
|
||||
yield await SSEResponse.send_progress("分析已有职业...", 5)
|
||||
|
||||
existing_careers_result = await db.execute(
|
||||
select(Career).where(Career.project_id == project_id)
|
||||
)
|
||||
existing_careers = existing_careers_result.scalars().all()
|
||||
|
||||
# 构建已有职业摘要
|
||||
existing_main_careers = []
|
||||
existing_sub_careers = []
|
||||
for career in existing_careers:
|
||||
career_summary = f"- {career.name}({career.category or '未分类'},{career.max_stage}阶)"
|
||||
if career.description:
|
||||
career_summary += f": {career.description[:50]}"
|
||||
|
||||
if career.type == "main":
|
||||
existing_main_careers.append(career_summary)
|
||||
else:
|
||||
existing_sub_careers.append(career_summary)
|
||||
|
||||
existing_careers_text = ""
|
||||
if existing_main_careers:
|
||||
existing_careers_text += f"\n已有主职业({len(existing_main_careers)}个):\n" + "\n".join(existing_main_careers)
|
||||
if existing_sub_careers:
|
||||
existing_careers_text += f"\n\n已有副职业({len(existing_sub_careers)}个):\n" + "\n".join(existing_sub_careers)
|
||||
|
||||
if not existing_careers_text:
|
||||
existing_careers_text = "\n当前还没有任何职业,这是第一次创建职业体系。"
|
||||
|
||||
# 构建项目上下文
|
||||
yield await SSEResponse.send_progress("分析项目世界观...", 15)
|
||||
|
||||
project_context = f"""
|
||||
项目信息:
|
||||
- 书名:{project.title}
|
||||
- 类型:{project.genre or '未设定'}
|
||||
- 主题:{project.theme or '未设定'}
|
||||
- 时间背景:{project.world_time_period or '未设定'}
|
||||
- 地理位置:{project.world_location or '未设定'}
|
||||
- 氛围基调:{project.world_atmosphere or '未设定'}
|
||||
- 世界规则:{project.world_rules or '未设定'}
|
||||
"""
|
||||
|
||||
user_requirements = f"""
|
||||
已有职业情况:{existing_careers_text}
|
||||
|
||||
生成要求(增量式):
|
||||
- 本次新增主职业:{main_career_count}个
|
||||
- 本次新增副职业:{sub_career_count}个
|
||||
- ⚠️ 重要:请生成与已有职业**不重复**的新职业,形成互补体系
|
||||
- 新职业应填补已有职业体系的空缺,丰富职业多样性
|
||||
- 主职业必须严格符合世界观规则,体现核心能力体系
|
||||
- 副职业可以更加自由灵活,包含生产、辅助、特殊类型
|
||||
"""
|
||||
|
||||
yield await SSEResponse.send_progress("构建AI提示词...", 20)
|
||||
|
||||
# 构建提示词
|
||||
prompt = f"""{project_context}
|
||||
|
||||
{user_requirements}
|
||||
|
||||
请为这个小说项目生成新的补充职业(增量式)。要求:
|
||||
1. **仔细分析已有职业**,避免生成重复或相似的职业
|
||||
2. **填补职业体系的空缺**,让职业体系更加完善和多样化
|
||||
3. 如果已有职业较少,可以生成核心基础职业
|
||||
4. 如果已有职业较多,可以生成特色化、专精化的职业
|
||||
|
||||
返回JSON格式,结构如下:
|
||||
|
||||
{{
|
||||
"main_careers": [
|
||||
{{
|
||||
"name": "职业名称",
|
||||
"description": "职业描述",
|
||||
"category": "职业分类(如:战斗系、法术系等)",
|
||||
"stages": [
|
||||
{{"level": 1, "name": "阶段名称", "description": "阶段描述"}},
|
||||
{{"level": 2, "name": "阶段名称", "description": "阶段描述"}},
|
||||
...
|
||||
],
|
||||
"max_stage": 10,
|
||||
"requirements": "职业要求",
|
||||
"special_abilities": "特殊能力",
|
||||
"worldview_rules": "世界观规则关联",
|
||||
"attribute_bonuses": {{"strength": "+10%", "intelligence": "+5%"}}
|
||||
}}
|
||||
],
|
||||
"sub_careers": [
|
||||
{{
|
||||
"name": "副职业名称",
|
||||
"description": "职业描述",
|
||||
"category": "生产系/辅助系/特殊系",
|
||||
"stages": [...],
|
||||
"max_stage": 5,
|
||||
"requirements": "职业要求",
|
||||
"special_abilities": "特殊能力"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
注意事项:
|
||||
1. **避免重复**:生成的职业名称和定位不能与已有职业重复
|
||||
2. **互补性**:新职业应与已有职业形成互补,丰富职业体系
|
||||
3. 主职业的阶段设定要详细,体现明确的成长路径
|
||||
4. 阶段名称要符合世界观特色
|
||||
5. 副职业可以相对简化,但要有独特性
|
||||
6. 所有职业都要符合项目的整体世界观设定
|
||||
7. 只返回纯JSON,不要添加任何解释文字
|
||||
"""
|
||||
|
||||
yield await SSEResponse.send_progress("调用AI生成新职业...", 10)
|
||||
logger.info(f"🎯 开始为项目 {project_id} 生成新职业(增量式,已有{len(existing_careers)}个职业)")
|
||||
|
||||
try:
|
||||
# 使用流式生成替代非流式
|
||||
ai_response = ""
|
||||
chunk_count = 0
|
||||
last_progress = 10
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
chunk_count += 1
|
||||
ai_response += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 平滑更新进度(10-90%,AI生成占60%)
|
||||
# 每10个chunk增加约1%的进度,最多到90%
|
||||
if chunk_count % 10 == 0:
|
||||
# 计算进度:10% + (chunk_count / 10) * 1%,但不超过90%
|
||||
current_progress = min(10 + (chunk_count // 10), 90)
|
||||
if current_progress > last_progress:
|
||||
last_progress = current_progress
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成职业体系中... (已生成 {len(ai_response)} 字符)",
|
||||
current_progress
|
||||
)
|
||||
|
||||
# 心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
except Exception as ai_error:
|
||||
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
|
||||
yield await SSEResponse.send_error(f"AI服务调用失败:{str(ai_error)}")
|
||||
return
|
||||
|
||||
if not ai_response or not ai_response.strip():
|
||||
yield await SSEResponse.send_error("AI服务返回空响应")
|
||||
return
|
||||
|
||||
yield await SSEResponse.send_progress("解析AI响应...", 91)
|
||||
|
||||
# 清洗并解析JSON
|
||||
try:
|
||||
cleaned_response = user_ai_service._clean_json_response(ai_response)
|
||||
career_data = json.loads(cleaned_response)
|
||||
logger.info(f"✅ 职业体系JSON解析成功")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"❌ 职业体系JSON解析失败: {e}")
|
||||
logger.error(f" 原始响应预览: {ai_response[:200]}")
|
||||
yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON:{str(e)}")
|
||||
return
|
||||
|
||||
yield await SSEResponse.send_progress("保存主职业到数据库...", 93)
|
||||
|
||||
# 保存主职业
|
||||
main_careers_created = []
|
||||
for idx, career_info in enumerate(career_data.get("main_careers", [])):
|
||||
try:
|
||||
stages_json = json.dumps(career_info.get("stages", []), ensure_ascii=False)
|
||||
attribute_bonuses = career_info.get("attribute_bonuses")
|
||||
attribute_bonuses_json = json.dumps(attribute_bonuses, ensure_ascii=False) if attribute_bonuses else None
|
||||
|
||||
career = Career(
|
||||
project_id=project_id,
|
||||
name=career_info.get("name", f"未命名主职业{idx+1}"),
|
||||
type="main",
|
||||
description=career_info.get("description"),
|
||||
category=career_info.get("category"),
|
||||
stages=stages_json,
|
||||
max_stage=career_info.get("max_stage", 10),
|
||||
requirements=career_info.get("requirements"),
|
||||
special_abilities=career_info.get("special_abilities"),
|
||||
worldview_rules=career_info.get("worldview_rules"),
|
||||
attribute_bonuses=attribute_bonuses_json,
|
||||
source="ai"
|
||||
)
|
||||
db.add(career)
|
||||
await db.flush()
|
||||
main_careers_created.append(career.name)
|
||||
logger.info(f" ✅ 创建主职业:{career.name}")
|
||||
except Exception as e:
|
||||
logger.error(f" ❌ 创建主职业失败:{str(e)}")
|
||||
continue
|
||||
|
||||
yield await SSEResponse.send_progress("保存副职业到数据库...", 96)
|
||||
|
||||
# 保存副职业
|
||||
sub_careers_created = []
|
||||
for idx, career_info in enumerate(career_data.get("sub_careers", [])):
|
||||
try:
|
||||
stages_json = json.dumps(career_info.get("stages", []), ensure_ascii=False)
|
||||
attribute_bonuses = career_info.get("attribute_bonuses")
|
||||
attribute_bonuses_json = json.dumps(attribute_bonuses, ensure_ascii=False) if attribute_bonuses else None
|
||||
|
||||
career = Career(
|
||||
project_id=project_id,
|
||||
name=career_info.get("name", f"未命名副职业{idx+1}"),
|
||||
type="sub",
|
||||
description=career_info.get("description"),
|
||||
category=career_info.get("category"),
|
||||
stages=stages_json,
|
||||
max_stage=career_info.get("max_stage", 5),
|
||||
requirements=career_info.get("requirements"),
|
||||
special_abilities=career_info.get("special_abilities"),
|
||||
worldview_rules=career_info.get("worldview_rules"),
|
||||
attribute_bonuses=attribute_bonuses_json,
|
||||
source="ai"
|
||||
)
|
||||
db.add(career)
|
||||
await db.flush()
|
||||
sub_careers_created.append(career.name)
|
||||
logger.info(f" ✅ 创建副职业:{career.name}")
|
||||
except Exception as e:
|
||||
logger.error(f" ❌ 创建副职业失败:{str(e)}")
|
||||
continue
|
||||
|
||||
await db.commit()
|
||||
|
||||
total_main = len(existing_main_careers) + len(main_careers_created)
|
||||
total_sub = len(existing_sub_careers) + len(sub_careers_created)
|
||||
|
||||
logger.info(f"🎉 新职业生成完成:新增主职业{len(main_careers_created)}个,新增副职业{len(sub_careers_created)}个")
|
||||
logger.info(f" 职业体系总数:主职业{total_main}个,副职业{total_sub}个")
|
||||
|
||||
yield await SSEResponse.send_progress(f"新职业生成完成!(主职业{total_main}个,副职业{total_sub}个)", 100, "success")
|
||||
|
||||
# 发送结果数据
|
||||
yield await SSEResponse.send_result({
|
||||
"main_careers_count": len(main_careers_created),
|
||||
"sub_careers_count": len(sub_careers_created),
|
||||
"main_careers": main_careers_created,
|
||||
"sub_careers": sub_careers_created
|
||||
})
|
||||
|
||||
yield await SSEResponse.send_done()
|
||||
|
||||
except HTTPException as he:
|
||||
logger.error(f"HTTP异常: {he.detail}")
|
||||
yield await SSEResponse.send_error(he.detail, he.status_code)
|
||||
except Exception as e:
|
||||
logger.error(f"生成职业体系失败: {str(e)}")
|
||||
yield await SSEResponse.send_error(f"生成新职业失败: {str(e)}")
|
||||
|
||||
return create_sse_response(generate())
|
||||
|
||||
|
||||
@router.put("/{career_id}", response_model=CareerResponse, summary="更新职业")
|
||||
async def update_career(
|
||||
career_id: str,
|
||||
career_update: CareerUpdate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""更新职业信息"""
|
||||
result = await db.execute(
|
||||
select(Career).where(Career.id == career_id)
|
||||
)
|
||||
career = result.scalar_one_or_none()
|
||||
|
||||
if not career:
|
||||
raise HTTPException(status_code=404, detail="职业不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(career.project_id, user_id, db)
|
||||
|
||||
# 更新字段
|
||||
update_data = career_update.model_dump(exclude_unset=True)
|
||||
|
||||
for field, value in update_data.items():
|
||||
if field == "stages" and value is not None:
|
||||
# 转换为JSON字符串
|
||||
# model_dump() 已经将嵌套模型转换为字典,所以 value 中的元素已经是 dict
|
||||
stages_list = [
|
||||
stage if isinstance(stage, dict) else stage.model_dump()
|
||||
for stage in value
|
||||
]
|
||||
setattr(career, field, json.dumps(stages_list, ensure_ascii=False))
|
||||
elif field == "attribute_bonuses" and value is not None:
|
||||
# 转换为JSON字符串
|
||||
setattr(career, field, json.dumps(value, ensure_ascii=False))
|
||||
else:
|
||||
setattr(career, field, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(career)
|
||||
|
||||
logger.info(f"✅ 更新职业成功:{career.name} (ID: {career_id})")
|
||||
|
||||
# 解析JSON返回
|
||||
stages = json.loads(career.stages) if career.stages else []
|
||||
attribute_bonuses = json.loads(career.attribute_bonuses) if career.attribute_bonuses else None
|
||||
|
||||
return CareerResponse(
|
||||
id=career.id,
|
||||
project_id=career.project_id,
|
||||
name=career.name,
|
||||
type=career.type,
|
||||
description=career.description,
|
||||
category=career.category,
|
||||
stages=stages,
|
||||
max_stage=career.max_stage,
|
||||
requirements=career.requirements,
|
||||
special_abilities=career.special_abilities,
|
||||
worldview_rules=career.worldview_rules,
|
||||
attribute_bonuses=attribute_bonuses,
|
||||
source=career.source,
|
||||
created_at=career.created_at,
|
||||
updated_at=career.updated_at
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{career_id}", summary="删除职业")
|
||||
async def delete_career(
|
||||
career_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""删除职业"""
|
||||
result = await db.execute(
|
||||
select(Career).where(Career.id == career_id)
|
||||
)
|
||||
career = result.scalar_one_or_none()
|
||||
|
||||
if not career:
|
||||
raise HTTPException(status_code=404, detail="职业不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(career.project_id, user_id, db)
|
||||
|
||||
# 检查是否有角色使用该职业
|
||||
char_career_result = await db.execute(
|
||||
select(func.count(CharacterCareer.id)).where(CharacterCareer.career_id == career_id)
|
||||
)
|
||||
usage_count = char_career_result.scalar_one()
|
||||
|
||||
if usage_count > 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"该职业被{usage_count}个角色使用,无法删除。请先移除角色的职业关联。"
|
||||
)
|
||||
|
||||
await db.delete(career)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"✅ 删除职业成功:{career.name} (ID: {career_id})")
|
||||
|
||||
return {"message": "职业删除成功"}
|
||||
|
||||
|
||||
@router.get("/{career_id}", response_model=CareerResponse, summary="获取职业详情")
|
||||
async def get_career(
|
||||
career_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""根据ID获取职业详情"""
|
||||
result = await db.execute(
|
||||
select(Career).where(Career.id == career_id)
|
||||
)
|
||||
career = result.scalar_one_or_none()
|
||||
|
||||
if not career:
|
||||
raise HTTPException(status_code=404, detail="职业不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(career.project_id, user_id, db)
|
||||
|
||||
# 解析JSON字段
|
||||
stages = json.loads(career.stages) if career.stages else []
|
||||
attribute_bonuses = json.loads(career.attribute_bonuses) if career.attribute_bonuses else None
|
||||
|
||||
return CareerResponse(
|
||||
id=career.id,
|
||||
project_id=career.project_id,
|
||||
name=career.name,
|
||||
type=career.type,
|
||||
description=career.description,
|
||||
category=career.category,
|
||||
stages=stages,
|
||||
max_stage=career.max_stage,
|
||||
requirements=career.requirements,
|
||||
special_abilities=career.special_abilities,
|
||||
worldview_rules=career.worldview_rules,
|
||||
attribute_bonuses=attribute_bonuses,
|
||||
source=career.source,
|
||||
created_at=career.created_at,
|
||||
updated_at=career.updated_at
|
||||
)
|
||||
|
||||
|
||||
# ===== 角色职业关联API =====
|
||||
|
||||
@router.get("/character/{character_id}/careers", response_model=CharacterCareerResponse, summary="获取角色的职业信息")
|
||||
async def get_character_careers(
|
||||
character_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取角色的所有职业信息(主职业和副职业)"""
|
||||
# 验证角色存在
|
||||
char_result = await db.execute(
|
||||
select(Character).where(Character.id == character_id)
|
||||
)
|
||||
character = char_result.scalar_one_or_none()
|
||||
|
||||
if not character:
|
||||
raise HTTPException(status_code=404, detail="角色不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(character.project_id, user_id, db)
|
||||
|
||||
# 获取角色的所有职业关联
|
||||
result = await db.execute(
|
||||
select(CharacterCareer, Career)
|
||||
.join(Career, CharacterCareer.career_id == Career.id)
|
||||
.where(CharacterCareer.character_id == character_id)
|
||||
.order_by(CharacterCareer.career_type.desc()) # main排在前
|
||||
)
|
||||
career_relations = result.all()
|
||||
|
||||
main_career = None
|
||||
sub_careers = []
|
||||
|
||||
for char_career, career in career_relations:
|
||||
# 解析职业的阶段信息
|
||||
stages = json.loads(career.stages) if career.stages else []
|
||||
|
||||
# 找到当前阶段信息
|
||||
stage_name = "未知阶段"
|
||||
stage_description = None
|
||||
for stage in stages:
|
||||
if stage.get("level") == char_career.current_stage:
|
||||
stage_name = stage.get("name", f"第{char_career.current_stage}阶段")
|
||||
stage_description = stage.get("description")
|
||||
break
|
||||
|
||||
career_detail = CharacterCareerDetail(
|
||||
id=char_career.id,
|
||||
character_id=char_career.character_id,
|
||||
career_id=char_career.career_id,
|
||||
career_name=career.name,
|
||||
career_type=char_career.career_type,
|
||||
current_stage=char_career.current_stage,
|
||||
stage_name=stage_name,
|
||||
stage_description=stage_description,
|
||||
stage_progress=char_career.stage_progress,
|
||||
max_stage=career.max_stage,
|
||||
started_at=char_career.started_at,
|
||||
reached_current_stage_at=char_career.reached_current_stage_at,
|
||||
notes=char_career.notes,
|
||||
created_at=char_career.created_at,
|
||||
updated_at=char_career.updated_at
|
||||
)
|
||||
|
||||
if char_career.career_type == "main":
|
||||
main_career = career_detail
|
||||
else:
|
||||
sub_careers.append(career_detail)
|
||||
|
||||
return CharacterCareerResponse(
|
||||
main_career=main_career,
|
||||
sub_careers=sub_careers
|
||||
)
|
||||
|
||||
|
||||
@router.post("/character/{character_id}/careers/main", summary="设置角色主职业")
|
||||
async def set_main_career(
|
||||
character_id: str,
|
||||
career_request: SetMainCareerRequest,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""设置或更换角色的主职业"""
|
||||
# 验证角色存在
|
||||
char_result = await db.execute(
|
||||
select(Character).where(Character.id == character_id)
|
||||
)
|
||||
character = char_result.scalar_one_or_none()
|
||||
|
||||
if not character:
|
||||
raise HTTPException(status_code=404, detail="角色不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(character.project_id, user_id, db)
|
||||
|
||||
# 验证职业存在且为主职业类型
|
||||
career_result = await db.execute(
|
||||
select(Career).where(
|
||||
Career.id == career_request.career_id,
|
||||
Career.project_id == character.project_id
|
||||
)
|
||||
)
|
||||
career = career_result.scalar_one_or_none()
|
||||
|
||||
if not career:
|
||||
raise HTTPException(status_code=404, detail="职业不存在")
|
||||
|
||||
if career.type != "main":
|
||||
raise HTTPException(status_code=400, detail="该职业不是主职业类型,无法设置为主职业")
|
||||
|
||||
# 验证阶段有效性
|
||||
if career_request.current_stage > career.max_stage:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"阶段超出范围,该职业最大阶段为{career.max_stage}"
|
||||
)
|
||||
|
||||
# 检查是否已有主职业
|
||||
existing_main = await db.execute(
|
||||
select(CharacterCareer).where(
|
||||
CharacterCareer.character_id == character_id,
|
||||
CharacterCareer.career_type == "main"
|
||||
)
|
||||
)
|
||||
current_main = existing_main.scalar_one_or_none()
|
||||
|
||||
if current_main:
|
||||
# 删除旧的主职业
|
||||
await db.delete(current_main)
|
||||
logger.info(f" 移除旧主职业关联: {current_main.career_id}")
|
||||
|
||||
# 创建新的主职业关联
|
||||
char_career = CharacterCareer(
|
||||
character_id=character_id,
|
||||
career_id=career_request.career_id,
|
||||
career_type="main",
|
||||
current_stage=career_request.current_stage,
|
||||
stage_progress=0,
|
||||
started_at=career_request.started_at,
|
||||
reached_current_stage_at=career_request.started_at
|
||||
)
|
||||
db.add(char_career)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"✅ 设置主职业成功:角色{character.name} -> {career.name}(第{career_request.current_stage}阶段)")
|
||||
|
||||
return {"message": "主职业设置成功", "career_name": career.name}
|
||||
|
||||
|
||||
@router.post("/character/{character_id}/careers/sub", summary="添加角色副职业")
|
||||
async def add_sub_career(
|
||||
character_id: str,
|
||||
career_request: AddSubCareerRequest,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""为角色添加副职业"""
|
||||
# 验证角色存在
|
||||
char_result = await db.execute(
|
||||
select(Character).where(Character.id == character_id)
|
||||
)
|
||||
character = char_result.scalar_one_or_none()
|
||||
|
||||
if not character:
|
||||
raise HTTPException(status_code=404, detail="角色不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(character.project_id, user_id, db)
|
||||
|
||||
# 验证职业存在且为副职业类型
|
||||
career_result = await db.execute(
|
||||
select(Career).where(
|
||||
Career.id == career_request.career_id,
|
||||
Career.project_id == character.project_id
|
||||
)
|
||||
)
|
||||
career = career_result.scalar_one_or_none()
|
||||
|
||||
if not career:
|
||||
raise HTTPException(status_code=404, detail="职业不存在")
|
||||
|
||||
if career.type != "sub":
|
||||
raise HTTPException(status_code=400, detail="该职业不是副职业类型,无法添加为副职业")
|
||||
|
||||
# 验证阶段有效性
|
||||
if career_request.current_stage > career.max_stage:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"阶段超出范围,该职业最大阶段为{career.max_stage}"
|
||||
)
|
||||
|
||||
# 检查是否已存在
|
||||
existing_check = await db.execute(
|
||||
select(CharacterCareer).where(
|
||||
CharacterCareer.character_id == character_id,
|
||||
CharacterCareer.career_id == career_request.career_id
|
||||
)
|
||||
)
|
||||
if existing_check.scalar_one_or_none():
|
||||
raise HTTPException(status_code=400, detail="该角色已拥有此副职业")
|
||||
|
||||
# 检查副职业数量限制(可选,这里设置为最多5个)
|
||||
sub_count_result = await db.execute(
|
||||
select(func.count(CharacterCareer.id)).where(
|
||||
CharacterCareer.character_id == character_id,
|
||||
CharacterCareer.career_type == "sub"
|
||||
)
|
||||
)
|
||||
sub_count = sub_count_result.scalar_one()
|
||||
|
||||
if sub_count >= 5:
|
||||
raise HTTPException(status_code=400, detail="副职业数量已达上限(最多5个)")
|
||||
|
||||
# 创建副职业关联
|
||||
char_career = CharacterCareer(
|
||||
character_id=character_id,
|
||||
career_id=career_request.career_id,
|
||||
career_type="sub",
|
||||
current_stage=career_request.current_stage,
|
||||
stage_progress=0,
|
||||
started_at=career_request.started_at,
|
||||
reached_current_stage_at=career_request.started_at
|
||||
)
|
||||
db.add(char_career)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"✅ 添加副职业成功:角色{character.name} -> {career.name}(第{career_request.current_stage}阶段)")
|
||||
|
||||
return {"message": "副职业添加成功", "career_name": career.name}
|
||||
|
||||
|
||||
@router.put("/character/{character_id}/careers/{career_id}/stage", summary="更新职业阶段")
|
||||
async def update_career_stage(
|
||||
character_id: str,
|
||||
career_id: str,
|
||||
stage_request: UpdateCareerStageRequest,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""更新角色在某个职业的阶段"""
|
||||
# 验证角色职业关联存在
|
||||
result = await db.execute(
|
||||
select(CharacterCareer, Career, Character)
|
||||
.join(Career, CharacterCareer.career_id == Career.id)
|
||||
.join(Character, CharacterCareer.character_id == Character.id)
|
||||
.where(
|
||||
CharacterCareer.character_id == character_id,
|
||||
CharacterCareer.career_id == career_id
|
||||
)
|
||||
)
|
||||
relation_data = result.one_or_none()
|
||||
|
||||
if not relation_data:
|
||||
raise HTTPException(status_code=404, detail="角色职业关联不存在")
|
||||
|
||||
char_career, career, character = relation_data
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(character.project_id, user_id, db)
|
||||
|
||||
# 验证新阶段有效性
|
||||
if stage_request.current_stage > career.max_stage:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"阶段超出范围,该职业最大阶段为{career.max_stage}"
|
||||
)
|
||||
|
||||
# 验证阶段递增规则(不能倒退,除非降级)
|
||||
if stage_request.current_stage < char_career.current_stage:
|
||||
logger.warning(f"⚠️ 角色{character.name}的职业{career.name}阶段降低:{char_career.current_stage} -> {stage_request.current_stage}")
|
||||
|
||||
# 更新阶段信息
|
||||
char_career.current_stage = stage_request.current_stage
|
||||
char_career.stage_progress = stage_request.stage_progress
|
||||
if stage_request.reached_current_stage_at:
|
||||
char_career.reached_current_stage_at = stage_request.reached_current_stage_at
|
||||
if stage_request.notes is not None:
|
||||
char_career.notes = stage_request.notes
|
||||
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"✅ 更新职业阶段成功:{character.name}的{career.name} -> 第{stage_request.current_stage}阶段")
|
||||
|
||||
return {
|
||||
"message": "职业阶段更新成功",
|
||||
"career_name": career.name,
|
||||
"new_stage": stage_request.current_stage
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/character/{character_id}/careers/{career_id}", summary="删除角色副职业")
|
||||
async def remove_sub_career(
|
||||
character_id: str,
|
||||
career_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""删除角色的副职业"""
|
||||
# 验证角色职业关联存在
|
||||
result = await db.execute(
|
||||
select(CharacterCareer, Character)
|
||||
.join(Character, CharacterCareer.character_id == Character.id)
|
||||
.where(
|
||||
CharacterCareer.character_id == character_id,
|
||||
CharacterCareer.career_id == career_id
|
||||
)
|
||||
)
|
||||
relation_data = result.one_or_none()
|
||||
|
||||
if not relation_data:
|
||||
raise HTTPException(status_code=404, detail="角色职业关联不存在")
|
||||
|
||||
char_career, character = relation_data
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(character.project_id, user_id, db)
|
||||
|
||||
# 不允许删除主职业
|
||||
if char_career.career_type == "main":
|
||||
raise HTTPException(status_code=400, detail="无法删除主职业,只能更换")
|
||||
|
||||
await db.delete(char_career)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"✅ 删除副职业成功:角色{character.name}移除职业{career_id}")
|
||||
|
||||
return {"message": "副职业删除成功"}
|
||||
@@ -0,0 +1,233 @@
|
||||
"""
|
||||
更新日志API
|
||||
提供GitHub提交历史的缓存和代理服务
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from typing import List, Optional
|
||||
import httpx
|
||||
from datetime import datetime, timedelta
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# GitHub API配置
|
||||
GITHUB_API_BASE = "https://api.github.com"
|
||||
REPO_OWNER = "xiamuceer-j"
|
||||
REPO_NAME = "MuMuAINovel"
|
||||
|
||||
# 缓存配置
|
||||
_cache = {
|
||||
"data": None,
|
||||
"timestamp": None,
|
||||
"ttl": timedelta(hours=1) # 缓存1小时
|
||||
}
|
||||
|
||||
|
||||
class GitHubAuthor(BaseModel):
|
||||
"""GitHub作者信息"""
|
||||
name: str
|
||||
email: str
|
||||
date: str
|
||||
|
||||
|
||||
class GitHubCommitInfo(BaseModel):
|
||||
"""GitHub提交信息"""
|
||||
author: GitHubAuthor
|
||||
message: str
|
||||
|
||||
|
||||
class GitHubUser(BaseModel):
|
||||
"""GitHub用户信息"""
|
||||
login: str
|
||||
avatar_url: str
|
||||
|
||||
|
||||
class GitHubCommit(BaseModel):
|
||||
"""GitHub提交数据"""
|
||||
sha: str
|
||||
commit: GitHubCommitInfo
|
||||
html_url: str
|
||||
author: Optional[GitHubUser] = None
|
||||
|
||||
|
||||
class ChangelogResponse(BaseModel):
|
||||
"""更新日志响应"""
|
||||
commits: List[GitHubCommit]
|
||||
cached: bool
|
||||
cache_time: Optional[str] = None
|
||||
|
||||
|
||||
def is_cache_valid() -> bool:
|
||||
"""检查缓存是否有效"""
|
||||
if _cache["data"] is None or _cache["timestamp"] is None:
|
||||
return False
|
||||
|
||||
now = datetime.now()
|
||||
cache_age = now - _cache["timestamp"]
|
||||
|
||||
return cache_age < _cache["ttl"]
|
||||
|
||||
|
||||
async def fetch_github_commits(page: int = 1, per_page: int = 30) -> List[dict]:
|
||||
"""从GitHub API获取提交历史"""
|
||||
url = f"{GITHUB_API_BASE}/repos/{REPO_OWNER}/{REPO_NAME}/commits"
|
||||
params = {
|
||||
"author": REPO_OWNER,
|
||||
"page": page,
|
||||
"per_page": per_page
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
"User-Agent": "MuMuAINovel-App"
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url, params=params, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"GitHub API请求失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"获取GitHub提交历史失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/changelog", response_model=ChangelogResponse)
|
||||
async def get_changelog(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
per_page: int = Query(30, ge=1, le=100, description="每页数量")
|
||||
):
|
||||
"""
|
||||
获取更新日志
|
||||
|
||||
从GitHub获取项目的提交历史,支持缓存以减少API调用
|
||||
|
||||
- **page**: 页码,从1开始
|
||||
- **per_page**: 每页返回的提交数量,最大100
|
||||
"""
|
||||
try:
|
||||
# 只缓存第一页
|
||||
if page == 1 and is_cache_valid():
|
||||
logger.info("使用缓存的更新日志")
|
||||
return ChangelogResponse(
|
||||
commits=_cache["data"],
|
||||
cached=True,
|
||||
cache_time=_cache["timestamp"].isoformat()
|
||||
)
|
||||
|
||||
# 从GitHub获取数据
|
||||
logger.info(f"从GitHub获取更新日志 (page={page}, per_page={per_page})")
|
||||
commits_data = await fetch_github_commits(page, per_page)
|
||||
|
||||
# 解析数据
|
||||
commits = []
|
||||
for commit_data in commits_data:
|
||||
try:
|
||||
commit = GitHubCommit(
|
||||
sha=commit_data["sha"],
|
||||
commit=GitHubCommitInfo(
|
||||
author=GitHubAuthor(
|
||||
name=commit_data["commit"]["author"]["name"],
|
||||
email=commit_data["commit"]["author"]["email"],
|
||||
date=commit_data["commit"]["author"]["date"]
|
||||
),
|
||||
message=commit_data["commit"]["message"]
|
||||
),
|
||||
html_url=commit_data["html_url"],
|
||||
author=GitHubUser(
|
||||
login=commit_data["author"]["login"],
|
||||
avatar_url=commit_data["author"]["avatar_url"]
|
||||
) if commit_data.get("author") else None
|
||||
)
|
||||
commits.append(commit)
|
||||
except (KeyError, TypeError) as e:
|
||||
logger.warning(f"解析提交数据失败: {str(e)}")
|
||||
continue
|
||||
|
||||
# 缓存第一页数据
|
||||
if page == 1:
|
||||
_cache["data"] = commits
|
||||
_cache["timestamp"] = datetime.now()
|
||||
logger.info("已缓存更新日志")
|
||||
|
||||
return ChangelogResponse(
|
||||
commits=commits,
|
||||
cached=False,
|
||||
cache_time=None
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取更新日志时发生错误: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"获取更新日志失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/changelog/refresh")
|
||||
async def refresh_changelog():
|
||||
"""
|
||||
刷新更新日志缓存
|
||||
|
||||
强制从GitHub重新获取最新的提交历史
|
||||
"""
|
||||
try:
|
||||
logger.info("刷新更新日志缓存")
|
||||
|
||||
# 清除缓存
|
||||
_cache["data"] = None
|
||||
_cache["timestamp"] = None
|
||||
|
||||
# 重新获取
|
||||
commits_data = await fetch_github_commits(1, 30)
|
||||
|
||||
# 解析数据
|
||||
commits = []
|
||||
for commit_data in commits_data:
|
||||
try:
|
||||
commit = GitHubCommit(
|
||||
sha=commit_data["sha"],
|
||||
commit=GitHubCommitInfo(
|
||||
author=GitHubAuthor(
|
||||
name=commit_data["commit"]["author"]["name"],
|
||||
email=commit_data["commit"]["author"]["email"],
|
||||
date=commit_data["commit"]["author"]["date"]
|
||||
),
|
||||
message=commit_data["commit"]["message"]
|
||||
),
|
||||
html_url=commit_data["html_url"],
|
||||
author=GitHubUser(
|
||||
login=commit_data["author"]["login"],
|
||||
avatar_url=commit_data["author"]["avatar_url"]
|
||||
) if commit_data.get("author") else None
|
||||
)
|
||||
commits.append(commit)
|
||||
except (KeyError, TypeError) as e:
|
||||
logger.warning(f"解析提交数据失败: {str(e)}")
|
||||
continue
|
||||
|
||||
# 更新缓存
|
||||
_cache["data"] = commits
|
||||
_cache["timestamp"] = datetime.now()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "缓存已刷新",
|
||||
"commit_count": len(commits),
|
||||
"cache_time": _cache["timestamp"].isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"刷新缓存时发生错误: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"刷新缓存失败: {str(e)}"
|
||||
)
|
||||
+2150
-168
File diff suppressed because it is too large
Load Diff
+1193
-317
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,490 @@
|
||||
"""灵感模式API - 通过对话引导创建项目"""
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Dict, Any
|
||||
import json
|
||||
|
||||
from app.database import get_db
|
||||
from app.services.ai_service import AIService
|
||||
from app.api.settings import get_user_ai_service
|
||||
from app.services.prompt_service import PromptService
|
||||
from app.logger import get_logger
|
||||
|
||||
router = APIRouter(prefix="/inspiration", tags=["灵感模式"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# 不同阶段的temperature设置(递减以保持一致性)
|
||||
TEMPERATURE_SETTINGS = {
|
||||
"title": 0.8, # 书名阶段可以更有创意
|
||||
"description": 0.65, # 简介需要贴合书名和原始想法
|
||||
"theme": 0.55, # 主题需要更加贴合
|
||||
"genre": 0.45 # 类型应该很明确
|
||||
}
|
||||
|
||||
|
||||
def validate_options_response(result: Dict[str, Any], step: str, max_retries: int = 3) -> tuple[bool, str]:
|
||||
"""
|
||||
校验AI返回的选项格式是否正确
|
||||
|
||||
Returns:
|
||||
(is_valid, error_message)
|
||||
"""
|
||||
# 检查必需字段
|
||||
if "options" not in result:
|
||||
return False, "缺少options字段"
|
||||
|
||||
options = result.get("options", [])
|
||||
|
||||
# 检查options是否为数组
|
||||
if not isinstance(options, list):
|
||||
return False, "options必须是数组"
|
||||
|
||||
# 检查数组长度
|
||||
if len(options) < 3:
|
||||
return False, f"选项数量不足,至少需要3个,当前只有{len(options)}个"
|
||||
|
||||
if len(options) > 10:
|
||||
return False, f"选项数量过多,最多10个,当前有{len(options)}个"
|
||||
|
||||
# 检查每个选项是否为字符串且不为空
|
||||
for i, option in enumerate(options):
|
||||
if not isinstance(option, str):
|
||||
return False, f"第{i+1}个选项不是字符串类型"
|
||||
if not option.strip():
|
||||
return False, f"第{i+1}个选项为空"
|
||||
if len(option) > 500:
|
||||
return False, f"第{i+1}个选项过长(超过500字符)"
|
||||
|
||||
# 根据不同步骤进行特定校验
|
||||
if step == "genre":
|
||||
# 类型标签应该比较短
|
||||
for i, option in enumerate(options):
|
||||
if len(option) > 10:
|
||||
return False, f"类型标签【{option}】过长,应该在2-10字之间"
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
@router.post("/generate-options")
|
||||
async def generate_options(
|
||||
data: Dict[str, Any],
|
||||
http_request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
ai_service: AIService = Depends(get_user_ai_service)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
根据当前收集的信息生成下一步的选项建议(带自动重试)
|
||||
|
||||
Request:
|
||||
{
|
||||
"step": "title", // title/description/theme/genre
|
||||
"context": {
|
||||
"title": "...",
|
||||
"description": "...",
|
||||
"theme": "..."
|
||||
}
|
||||
}
|
||||
|
||||
Response:
|
||||
{
|
||||
"prompt": "引导语",
|
||||
"options": ["选项1", "选项2", ...]
|
||||
}
|
||||
"""
|
||||
max_retries = 3
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
step = data.get("step", "title")
|
||||
context = data.get("context", {})
|
||||
|
||||
logger.info(f"灵感模式:生成{step}阶段的选项(第{attempt + 1}次尝试)")
|
||||
|
||||
# 获取用户ID
|
||||
user_id = getattr(http_request.state, 'user_id', None)
|
||||
|
||||
# 获取对应的提示词模板(根据step确定模板key)
|
||||
# 新结构:每个步骤有独立的 SYSTEM 和 USER 模板
|
||||
template_key_map = {
|
||||
"title": ("INSPIRATION_TITLE_SYSTEM", "INSPIRATION_TITLE_USER"),
|
||||
"description": ("INSPIRATION_DESCRIPTION_SYSTEM", "INSPIRATION_DESCRIPTION_USER"),
|
||||
"theme": ("INSPIRATION_THEME_SYSTEM", "INSPIRATION_THEME_USER"),
|
||||
"genre": ("INSPIRATION_GENRE_SYSTEM", "INSPIRATION_GENRE_USER")
|
||||
}
|
||||
template_keys = template_key_map.get(step)
|
||||
|
||||
if not template_keys:
|
||||
return {
|
||||
"error": f"不支持的步骤: {step}",
|
||||
"prompt": "",
|
||||
"options": []
|
||||
}
|
||||
|
||||
system_key, user_key = template_keys
|
||||
|
||||
# 获取自定义提示词模板(分别获取 system 和 user)
|
||||
system_template = await PromptService.get_template(system_key, user_id, db)
|
||||
user_template = await PromptService.get_template(user_key, user_id, db)
|
||||
|
||||
# 准备格式化参数
|
||||
format_params = {
|
||||
"initial_idea": context.get("initial_idea", context.get("description", "")),
|
||||
"title": context.get("title", ""),
|
||||
"description": context.get("description", ""),
|
||||
"theme": context.get("theme", "")
|
||||
}
|
||||
|
||||
# 格式化提示词
|
||||
system_prompt = system_template.format(**format_params)
|
||||
user_prompt = user_template.format(**format_params)
|
||||
|
||||
# 如果是重试,在提示词中强调格式要求
|
||||
if attempt > 0:
|
||||
system_prompt += f"\n\n⚠️ 这是第{attempt + 1}次生成,请务必严格按照JSON格式返回,确保options数组包含6个有效选项!"
|
||||
|
||||
# 调用AI生成选项
|
||||
# 关键改进:使用递减的temperature以保持后续阶段与前文的一致性
|
||||
temperature = TEMPERATURE_SETTINGS.get(step, 0.7)
|
||||
logger.info(f"调用AI生成{step}选项... (temperature={temperature})")
|
||||
|
||||
# 流式生成并累积文本
|
||||
accumulated_text = ""
|
||||
async for chunk in ai_service.generate_text_stream(
|
||||
prompt=user_prompt,
|
||||
system_prompt=system_prompt,
|
||||
temperature=temperature
|
||||
):
|
||||
accumulated_text += chunk
|
||||
|
||||
response = {"content": accumulated_text}
|
||||
content = accumulated_text
|
||||
logger.info(f"AI返回内容长度: {len(content)}")
|
||||
|
||||
# 解析JSON(使用统一的JSON清洗方法)
|
||||
try:
|
||||
# 使用统一的JSON清洗方法
|
||||
cleaned_content = ai_service._clean_json_response(content)
|
||||
|
||||
result = json.loads(cleaned_content)
|
||||
|
||||
# 校验返回格式
|
||||
is_valid, error_msg = validate_options_response(result, step)
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(f"⚠️ 第{attempt + 1}次生成格式校验失败: {error_msg}")
|
||||
if attempt < max_retries - 1:
|
||||
logger.info("准备重试...")
|
||||
continue # 重试
|
||||
else:
|
||||
# 最后一次尝试也失败了
|
||||
return {
|
||||
"prompt": f"请为【{step}】提供内容:",
|
||||
"options": ["让AI重新生成", "我自己输入"],
|
||||
"error": f"AI生成格式错误({error_msg}),已自动重试{max_retries}次,请手动重试或自己输入"
|
||||
}
|
||||
|
||||
logger.info(f"✅ 第{attempt + 1}次成功生成{len(result.get('options', []))}个有效选项")
|
||||
return result
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"第{attempt + 1}次JSON解析失败: {e}")
|
||||
|
||||
if attempt < max_retries - 1:
|
||||
logger.info("JSON解析失败,准备重试...")
|
||||
continue # 重试
|
||||
else:
|
||||
# 最后一次尝试也失败了
|
||||
return {
|
||||
"prompt": f"请为【{step}】提供内容:",
|
||||
"options": ["让AI重新生成", "我自己输入"],
|
||||
"error": f"AI返回格式错误,已自动重试{max_retries}次,请手动重试或自己输入"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"第{attempt + 1}次生成失败: {e}", exc_info=True)
|
||||
if attempt < max_retries - 1:
|
||||
logger.info("发生异常,准备重试...")
|
||||
continue
|
||||
else:
|
||||
return {
|
||||
"error": str(e),
|
||||
"prompt": "生成失败,请重试",
|
||||
"options": ["重新生成", "我自己输入"]
|
||||
}
|
||||
|
||||
# 理论上不会到这里
|
||||
return {
|
||||
"error": "生成失败",
|
||||
"prompt": "请重试",
|
||||
"options": []
|
||||
}
|
||||
|
||||
|
||||
@router.post("/refine-options")
|
||||
async def refine_options(
|
||||
data: Dict[str, Any],
|
||||
http_request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
ai_service: AIService = Depends(get_user_ai_service)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
基于用户反馈重新生成选项(支持多轮对话)
|
||||
|
||||
Request:
|
||||
{
|
||||
"step": "title", // 当前步骤
|
||||
"context": {
|
||||
"initial_idea": "...",
|
||||
"title": "...",
|
||||
"description": "...",
|
||||
"theme": "..."
|
||||
},
|
||||
"feedback": "我想要更悲剧一些的主题", // 用户反馈
|
||||
"previous_options": ["选项1", "选项2", ...] // 之前的选项(可选)
|
||||
}
|
||||
|
||||
Response:
|
||||
{
|
||||
"prompt": "引导语",
|
||||
"options": ["新选项1", "新选项2", ...]
|
||||
}
|
||||
"""
|
||||
max_retries = 3
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
step = data.get("step", "title")
|
||||
context = data.get("context", {})
|
||||
feedback = data.get("feedback", "")
|
||||
previous_options = data.get("previous_options", [])
|
||||
|
||||
logger.info(f"灵感模式:根据反馈重新生成{step}阶段的选项(第{attempt + 1}次尝试)")
|
||||
logger.info(f"用户反馈: {feedback}")
|
||||
|
||||
# 获取用户ID
|
||||
user_id = getattr(http_request.state, 'user_id', None)
|
||||
|
||||
# 获取对应的提示词模板
|
||||
template_key_map = {
|
||||
"title": ("INSPIRATION_TITLE_SYSTEM", "INSPIRATION_TITLE_USER"),
|
||||
"description": ("INSPIRATION_DESCRIPTION_SYSTEM", "INSPIRATION_DESCRIPTION_USER"),
|
||||
"theme": ("INSPIRATION_THEME_SYSTEM", "INSPIRATION_THEME_USER"),
|
||||
"genre": ("INSPIRATION_GENRE_SYSTEM", "INSPIRATION_GENRE_USER")
|
||||
}
|
||||
template_keys = template_key_map.get(step)
|
||||
|
||||
if not template_keys:
|
||||
return {
|
||||
"error": f"不支持的步骤: {step}",
|
||||
"prompt": "",
|
||||
"options": []
|
||||
}
|
||||
|
||||
system_key, user_key = template_keys
|
||||
|
||||
# 获取自定义提示词模板
|
||||
system_template = await PromptService.get_template(system_key, user_id, db)
|
||||
user_template = await PromptService.get_template(user_key, user_id, db)
|
||||
|
||||
# 准备格式化参数
|
||||
format_params = {
|
||||
"initial_idea": context.get("initial_idea", context.get("description", "")),
|
||||
"title": context.get("title", ""),
|
||||
"description": context.get("description", ""),
|
||||
"theme": context.get("theme", "")
|
||||
}
|
||||
|
||||
# 格式化提示词
|
||||
system_prompt = system_template.format(**format_params)
|
||||
user_prompt = user_template.format(**format_params)
|
||||
|
||||
# 添加反馈信息到提示词
|
||||
feedback_instruction = f"""
|
||||
|
||||
⚠️ 用户对之前的选项不太满意,提供了以下反馈:
|
||||
「{feedback}」
|
||||
|
||||
之前生成的选项:
|
||||
{chr(10).join([f"- {opt}" for opt in previous_options]) if previous_options else "(无)"}
|
||||
|
||||
请根据用户的反馈调整生成策略,提供更符合用户期望的新选项。
|
||||
注意:
|
||||
1. 仔细理解用户的反馈意图
|
||||
2. 生成的新选项要明显体现用户要求的调整方向
|
||||
3. 保持与已有上下文的一致性
|
||||
4. 确保返回6个有效选项
|
||||
"""
|
||||
|
||||
system_prompt += feedback_instruction
|
||||
|
||||
# 如果是重试,强调格式要求
|
||||
if attempt > 0:
|
||||
system_prompt += f"\n\n⚠️ 这是第{attempt + 1}次生成,请务必严格按照JSON格式返回!"
|
||||
|
||||
# 调用AI生成选项
|
||||
temperature = TEMPERATURE_SETTINGS.get(step, 0.7)
|
||||
# 反馈生成时使用稍高的temperature以获得更多样化的结果
|
||||
temperature = min(temperature + 0.1, 0.9)
|
||||
logger.info(f"调用AI根据反馈生成{step}选项... (temperature={temperature})")
|
||||
|
||||
# 流式生成并累积文本
|
||||
accumulated_text = ""
|
||||
async for chunk in ai_service.generate_text_stream(
|
||||
prompt=user_prompt,
|
||||
system_prompt=system_prompt,
|
||||
temperature=temperature
|
||||
):
|
||||
accumulated_text += chunk
|
||||
|
||||
content = accumulated_text
|
||||
logger.info(f"AI返回内容长度: {len(content)}")
|
||||
|
||||
# 解析JSON
|
||||
try:
|
||||
cleaned_content = ai_service._clean_json_response(content)
|
||||
result = json.loads(cleaned_content)
|
||||
|
||||
# 校验返回格式
|
||||
is_valid, error_msg = validate_options_response(result, step)
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(f"⚠️ 第{attempt + 1}次生成格式校验失败: {error_msg}")
|
||||
if attempt < max_retries - 1:
|
||||
logger.info("准备重试...")
|
||||
continue
|
||||
else:
|
||||
return {
|
||||
"prompt": f"请为【{step}】提供内容:",
|
||||
"options": ["让AI重新生成", "我自己输入"],
|
||||
"error": f"AI生成格式错误({error_msg}),已自动重试{max_retries}次"
|
||||
}
|
||||
|
||||
logger.info(f"✅ 第{attempt + 1}次根据反馈成功生成{len(result.get('options', []))}个有效选项")
|
||||
return result
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"第{attempt + 1}次JSON解析失败: {e}")
|
||||
|
||||
if attempt < max_retries - 1:
|
||||
logger.info("JSON解析失败,准备重试...")
|
||||
continue
|
||||
else:
|
||||
return {
|
||||
"prompt": f"请为【{step}】提供内容:",
|
||||
"options": ["让AI重新生成", "我自己输入"],
|
||||
"error": f"AI返回格式错误,已自动重试{max_retries}次"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"第{attempt + 1}次根据反馈生成失败: {e}", exc_info=True)
|
||||
if attempt < max_retries - 1:
|
||||
logger.info("发生异常,准备重试...")
|
||||
continue
|
||||
else:
|
||||
return {
|
||||
"error": str(e),
|
||||
"prompt": "生成失败,请重试",
|
||||
"options": ["重新生成", "我自己输入"]
|
||||
}
|
||||
|
||||
return {
|
||||
"error": "生成失败",
|
||||
"prompt": "请重试",
|
||||
"options": []
|
||||
}
|
||||
|
||||
|
||||
@router.post("/quick-generate")
|
||||
async def quick_generate(
|
||||
data: Dict[str, Any],
|
||||
http_request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
ai_service: AIService = Depends(get_user_ai_service)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
智能补全:根据用户已提供的部分信息,AI自动补全缺失字段
|
||||
|
||||
Request:
|
||||
{
|
||||
"title": "书名(可选)",
|
||||
"description": "简介(可选)",
|
||||
"theme": "主题(可选)",
|
||||
"genre": ["类型1", "类型2"](可选)
|
||||
}
|
||||
|
||||
Response:
|
||||
{
|
||||
"title": "补全的书名",
|
||||
"description": "补全的简介",
|
||||
"theme": "补全的主题",
|
||||
"genre": ["补全的类型"]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info("灵感模式:智能补全")
|
||||
|
||||
# 获取用户ID
|
||||
user_id = getattr(http_request.state, 'user_id', None)
|
||||
|
||||
# 构建补全提示词
|
||||
existing_info = []
|
||||
if data.get("title"):
|
||||
existing_info.append(f"- 书名:{data['title']}")
|
||||
if data.get("description"):
|
||||
existing_info.append(f"- 简介:{data['description']}")
|
||||
if data.get("theme"):
|
||||
existing_info.append(f"- 主题:{data['theme']}")
|
||||
if data.get("genre"):
|
||||
existing_info.append(f"- 类型:{', '.join(data['genre'])}")
|
||||
|
||||
existing_text = "\n".join(existing_info) if existing_info else "暂无信息"
|
||||
|
||||
# 获取自定义提示词模板
|
||||
system_template = await PromptService.get_template("INSPIRATION_QUICK_COMPLETE", user_id, db)
|
||||
|
||||
# 格式化提示词
|
||||
prompts = {
|
||||
"system": PromptService.format_prompt(system_template, existing=existing_text),
|
||||
"user": "请补全小说信息"
|
||||
}
|
||||
|
||||
# 调用AI - 流式生成并累积文本
|
||||
accumulated_text = ""
|
||||
async for chunk in ai_service.generate_text_stream(
|
||||
prompt=prompts["user"],
|
||||
system_prompt=prompts["system"],
|
||||
temperature=0.7
|
||||
):
|
||||
accumulated_text += chunk
|
||||
|
||||
response = {"content": accumulated_text}
|
||||
content = accumulated_text
|
||||
|
||||
# 解析JSON(使用统一的JSON清洗方法)
|
||||
try:
|
||||
# 使用统一的JSON清洗方法
|
||||
cleaned_content = ai_service._clean_json_response(content)
|
||||
|
||||
result = json.loads(cleaned_content)
|
||||
|
||||
# 合并用户已提供的信息(用户输入优先)
|
||||
final_result = {
|
||||
"title": data.get("title") or result.get("title", ""),
|
||||
"description": data.get("description") or result.get("description", ""),
|
||||
"theme": data.get("theme") or result.get("theme", ""),
|
||||
"genre": data.get("genre") or result.get("genre", [])
|
||||
}
|
||||
|
||||
logger.info(f"✅ 智能补全成功")
|
||||
return final_result
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON解析失败: {e}")
|
||||
raise Exception("AI返回格式错误,请重试")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"智能补全失败: {e}", exc_info=True)
|
||||
return {
|
||||
"error": str(e)
|
||||
}
|
||||
@@ -0,0 +1,801 @@
|
||||
"""MCP插件管理API
|
||||
|
||||
重构后使用统一的MCPClientFacade门面来管理所有MCP操作。
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.mcp_plugin import MCPPlugin
|
||||
from app.schemas.mcp_plugin import (
|
||||
MCPPluginCreate,
|
||||
MCPPluginSimpleCreate,
|
||||
MCPPluginUpdate,
|
||||
MCPPluginResponse,
|
||||
MCPToolCall,
|
||||
MCPTestResult
|
||||
)
|
||||
import json
|
||||
from app.user_manager import User
|
||||
from app.mcp import mcp_client, MCPPluginConfig, PluginStatus
|
||||
from app.services.mcp_test_service import mcp_test_service
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/mcp/plugins", tags=["MCP插件管理"])
|
||||
|
||||
|
||||
def require_login(request: Request) -> User:
|
||||
"""依赖:要求用户已登录"""
|
||||
if not hasattr(request.state, "user") or not request.state.user:
|
||||
raise HTTPException(status_code=401, detail="需要登录")
|
||||
return request.state.user
|
||||
|
||||
|
||||
async def _register_plugin_to_facade(plugin: MCPPlugin, user_id: str) -> bool:
|
||||
"""
|
||||
将插件注册到统一门面
|
||||
|
||||
Args:
|
||||
plugin: 插件对象
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
是否注册成功
|
||||
"""
|
||||
if plugin.plugin_type in ["http", "streamable_http", "sse"] and plugin.server_url:
|
||||
return await mcp_client.register(MCPPluginConfig(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin.plugin_name,
|
||||
url=plugin.server_url,
|
||||
plugin_type=plugin.plugin_type,
|
||||
headers=plugin.headers,
|
||||
timeout=plugin.config.get('timeout', 60.0) if plugin.config else 60.0
|
||||
))
|
||||
else:
|
||||
logger.warning(f"暂不支持的插件类型: {plugin.plugin_type}")
|
||||
return False
|
||||
|
||||
|
||||
@router.get("", response_model=List[MCPPluginResponse])
|
||||
async def list_plugins(
|
||||
enabled_only: bool = Query(False, description="只返回启用的插件"),
|
||||
category: Optional[str] = Query(None, description="按分类筛选"),
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取用户的所有MCP插件
|
||||
"""
|
||||
query = select(MCPPlugin).where(MCPPlugin.user_id == user.user_id)
|
||||
|
||||
if enabled_only:
|
||||
query = query.where(MCPPlugin.enabled == True)
|
||||
|
||||
if category:
|
||||
query = query.where(MCPPlugin.category == category)
|
||||
|
||||
query = query.order_by(MCPPlugin.sort_order, MCPPlugin.created_at)
|
||||
|
||||
result = await db.execute(query)
|
||||
plugins = result.scalars().all()
|
||||
|
||||
logger.info(f"用户 {user.user_id} 查询插件列表,共 {len(plugins)} 个")
|
||||
return plugins
|
||||
|
||||
|
||||
@router.post("", response_model=MCPPluginResponse)
|
||||
async def create_plugin(
|
||||
data: MCPPluginCreate,
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
创建新的MCP插件
|
||||
"""
|
||||
# 检查插件名是否已存在
|
||||
result = await db.execute(
|
||||
select(MCPPlugin).where(
|
||||
MCPPlugin.user_id == user.user_id,
|
||||
MCPPlugin.plugin_name == data.plugin_name
|
||||
)
|
||||
)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail=f"插件名已存在: {data.plugin_name}")
|
||||
|
||||
# 创建插件数据
|
||||
plugin_data = data.model_dump()
|
||||
|
||||
# 如果没有提供display_name,使用plugin_name作为默认值
|
||||
if not plugin_data.get("display_name"):
|
||||
plugin_data["display_name"] = plugin_data["plugin_name"]
|
||||
|
||||
# 创建插件
|
||||
plugin = MCPPlugin(
|
||||
user_id=user.user_id,
|
||||
**plugin_data
|
||||
)
|
||||
|
||||
db.add(plugin)
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
|
||||
# 如果启用,注册到统一门面
|
||||
if plugin.enabled:
|
||||
success = await _register_plugin_to_facade(plugin, user.user_id)
|
||||
if success:
|
||||
plugin.status = "active"
|
||||
else:
|
||||
plugin.status = "error"
|
||||
plugin.last_error = "加载失败"
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
|
||||
logger.info(f"用户 {user.user_id} 创建插件: {plugin.plugin_name}")
|
||||
return plugin
|
||||
|
||||
|
||||
@router.post("/simple", response_model=MCPPluginResponse)
|
||||
async def create_plugin_simple(
|
||||
data: MCPPluginSimpleCreate,
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
通过标准MCP配置JSON创建或更新插件(简化版)
|
||||
|
||||
接受格式:
|
||||
{
|
||||
"config_json": '{"mcpServers": {"exa": {"type": "http", "url": "...", "headers": {}}}}',
|
||||
"category": "search"
|
||||
}
|
||||
|
||||
自动从mcpServers中提取插件名称(取第一个键)
|
||||
如果插件已存在,则更新;否则创建新插件
|
||||
"""
|
||||
try:
|
||||
# 解析配置JSON
|
||||
config = json.loads(data.config_json)
|
||||
|
||||
# 验证格式
|
||||
if "mcpServers" not in config:
|
||||
raise HTTPException(status_code=400, detail="配置JSON必须包含mcpServers字段")
|
||||
|
||||
servers = config["mcpServers"]
|
||||
if not servers or len(servers) == 0:
|
||||
raise HTTPException(status_code=400, detail="mcpServers不能为空")
|
||||
|
||||
# 自动提取第一个插件名称
|
||||
plugin_name = list(servers.keys())[0]
|
||||
server_config = servers[plugin_name]
|
||||
|
||||
logger.info(f"从配置中提取插件名称: {plugin_name}")
|
||||
|
||||
# 提取配置
|
||||
server_type = server_config.get("type", "http")
|
||||
|
||||
if server_type not in ["http", "stdio", "streamable_http", "sse"]:
|
||||
raise HTTPException(status_code=400, detail=f"不支持的服务器类型: {server_type}")
|
||||
|
||||
# 检查插件名是否已存在
|
||||
result = await db.execute(
|
||||
select(MCPPlugin).where(
|
||||
MCPPlugin.user_id == user.user_id,
|
||||
MCPPlugin.plugin_name == plugin_name
|
||||
)
|
||||
)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
# 构建插件数据
|
||||
plugin_data = {
|
||||
"plugin_name": plugin_name,
|
||||
"display_name": plugin_name,
|
||||
"plugin_type": server_type,
|
||||
"enabled": data.enabled,
|
||||
"category": data.category,
|
||||
"sort_order": 0
|
||||
}
|
||||
|
||||
if server_type in ["http", "streamable_http", "sse"]:
|
||||
plugin_data["server_url"] = server_config.get("url")
|
||||
plugin_data["headers"] = server_config.get("headers", {})
|
||||
|
||||
if not plugin_data["server_url"]:
|
||||
raise HTTPException(status_code=400, detail=f"{server_type}类型插件必须提供url字段")
|
||||
|
||||
elif server_type == "stdio":
|
||||
plugin_data["command"] = server_config.get("command")
|
||||
plugin_data["args"] = server_config.get("args", [])
|
||||
plugin_data["env"] = server_config.get("env", {})
|
||||
|
||||
if not plugin_data["command"]:
|
||||
raise HTTPException(status_code=400, detail="Stdio类型插件必须提供command字段")
|
||||
|
||||
if existing:
|
||||
# 更新现有插件
|
||||
logger.info(f"插件 {plugin_name} 已存在,执行更新操作")
|
||||
|
||||
# 保存旧状态
|
||||
old_enabled = existing.enabled
|
||||
old_plugin_name = existing.plugin_name
|
||||
|
||||
# 更新字段
|
||||
for key, value in plugin_data.items():
|
||||
setattr(existing, key, value)
|
||||
|
||||
plugin = existing
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
|
||||
# 数据库完成后进行MCP操作
|
||||
if old_enabled:
|
||||
try:
|
||||
await mcp_client.unregister(user.user_id, old_plugin_name)
|
||||
except Exception as e:
|
||||
logger.warning(f"注销旧插件出错: {e}")
|
||||
|
||||
if plugin.enabled:
|
||||
try:
|
||||
success = await _register_plugin_to_facade(plugin, user.user_id)
|
||||
plugin.status = "active" if success else "error"
|
||||
plugin.last_error = None if success else "加载失败"
|
||||
await db.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"注册插件失败: {e}")
|
||||
plugin.status = "error"
|
||||
plugin.last_error = str(e)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"用户 {user.user_id} 更新插件: {plugin_name}")
|
||||
else:
|
||||
# 创建新插件
|
||||
plugin = MCPPlugin(
|
||||
user_id=user.user_id,
|
||||
**plugin_data
|
||||
)
|
||||
|
||||
db.add(plugin)
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
|
||||
# 数据库完成后进行MCP操作
|
||||
if plugin.enabled:
|
||||
try:
|
||||
success = await _register_plugin_to_facade(plugin, user.user_id)
|
||||
plugin.status = "active" if success else "error"
|
||||
plugin.last_error = None if success else "加载失败"
|
||||
await db.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"注册插件失败: {e}")
|
||||
plugin.status = "error"
|
||||
plugin.last_error = str(e)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"用户 {user.user_id} 通过简化配置创建插件: {plugin_name}")
|
||||
|
||||
return plugin
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置JSON格式错误: {str(e)}")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"创建插件失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"创建插件失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/{plugin_id}", response_model=MCPPluginResponse)
|
||||
async def get_plugin(
|
||||
plugin_id: str,
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取插件详情
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(MCPPlugin).where(
|
||||
MCPPlugin.id == plugin_id,
|
||||
MCPPlugin.user_id == user.user_id
|
||||
)
|
||||
)
|
||||
plugin = result.scalar_one_or_none()
|
||||
|
||||
if not plugin:
|
||||
raise HTTPException(status_code=404, detail="插件不存在")
|
||||
|
||||
return plugin
|
||||
|
||||
|
||||
@router.put("/{plugin_id}", response_model=MCPPluginResponse)
|
||||
async def update_plugin(
|
||||
plugin_id: str,
|
||||
data: MCPPluginUpdate,
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
更新插件配置
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(MCPPlugin).where(
|
||||
MCPPlugin.id == plugin_id,
|
||||
MCPPlugin.user_id == user.user_id
|
||||
)
|
||||
)
|
||||
plugin = result.scalar_one_or_none()
|
||||
|
||||
if not plugin:
|
||||
raise HTTPException(status_code=404, detail="插件不存在")
|
||||
|
||||
# 更新字段
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(plugin, key, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
|
||||
# 如果插件已启用,重新注册
|
||||
if plugin.enabled:
|
||||
await mcp_client.unregister(user.user_id, plugin.plugin_name)
|
||||
await _register_plugin_to_facade(plugin, user.user_id)
|
||||
|
||||
logger.info(f"用户 {user.user_id} 更新插件: {plugin.plugin_name}")
|
||||
return plugin
|
||||
|
||||
|
||||
@router.delete("/{plugin_id}")
|
||||
async def delete_plugin(
|
||||
plugin_id: str,
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
删除插件
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(MCPPlugin).where(
|
||||
MCPPlugin.id == plugin_id,
|
||||
MCPPlugin.user_id == user.user_id
|
||||
)
|
||||
)
|
||||
plugin = result.scalar_one_or_none()
|
||||
|
||||
if not plugin:
|
||||
raise HTTPException(status_code=404, detail="插件不存在")
|
||||
|
||||
# 从统一门面注销
|
||||
await mcp_client.unregister(user.user_id, plugin.plugin_name)
|
||||
|
||||
# 删除数据库记录
|
||||
await db.delete(plugin)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"用户 {user.user_id} 删除插件: {plugin.plugin_name}")
|
||||
return {"message": "插件已删除", "plugin_name": plugin.plugin_name}
|
||||
|
||||
|
||||
@router.post("/{plugin_id}/toggle", response_model=MCPPluginResponse)
|
||||
async def toggle_plugin(
|
||||
plugin_id: str,
|
||||
enabled: bool = Query(..., description="启用或禁用"),
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
启用或禁用插件
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(MCPPlugin).where(
|
||||
MCPPlugin.id == plugin_id,
|
||||
MCPPlugin.user_id == user.user_id
|
||||
)
|
||||
)
|
||||
plugin = result.scalar_one_or_none()
|
||||
|
||||
if not plugin:
|
||||
raise HTTPException(status_code=404, detail="插件不存在")
|
||||
|
||||
# 保存插件信息用于后续MCP操作
|
||||
plugin_name = plugin.plugin_name
|
||||
plugin_type = plugin.plugin_type
|
||||
server_url = plugin.server_url
|
||||
headers = plugin.headers
|
||||
config = plugin.config
|
||||
|
||||
# 先更新数据库状态
|
||||
plugin.enabled = enabled
|
||||
if not enabled:
|
||||
plugin.status = "inactive"
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
|
||||
# 数据库操作完成后,再进行MCP操作
|
||||
if enabled:
|
||||
# 启用:注册到统一门面
|
||||
try:
|
||||
if plugin_type in ["http", "streamable_http", "sse"] and server_url:
|
||||
success = await mcp_client.register(MCPPluginConfig(
|
||||
user_id=user.user_id,
|
||||
plugin_name=plugin_name,
|
||||
url=server_url,
|
||||
plugin_type=plugin_type,
|
||||
headers=headers,
|
||||
timeout=config.get('timeout', 60.0) if config else 60.0
|
||||
))
|
||||
else:
|
||||
success = False
|
||||
|
||||
# 更新状态
|
||||
plugin.status = "active" if success else "error"
|
||||
plugin.last_error = None if success else "加载失败"
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
except Exception as e:
|
||||
logger.error(f"注册插件失败: {plugin_name}, 错误: {e}")
|
||||
plugin.status = "error"
|
||||
plugin.last_error = str(e)
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
else:
|
||||
# 禁用:从统一门面注销(不影响数据库状态)
|
||||
try:
|
||||
await mcp_client.unregister(user.user_id, plugin_name)
|
||||
except Exception as e:
|
||||
logger.warning(f"注销插件时出错(可忽略): {plugin_name}, 错误: {e}")
|
||||
|
||||
action = "启用" if enabled else "禁用"
|
||||
logger.info(f"用户 {user.user_id} {action}插件: {plugin_name}")
|
||||
return plugin
|
||||
|
||||
|
||||
@router.post("/{plugin_id}/test", response_model=MCPTestResult)
|
||||
async def test_plugin(
|
||||
plugin_id: str,
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
测试插件连接并调用工具验证功能
|
||||
|
||||
使用MCPTestService进行测试
|
||||
"""
|
||||
|
||||
result = await db.execute(
|
||||
select(MCPPlugin).where(
|
||||
MCPPlugin.id == plugin_id,
|
||||
MCPPlugin.user_id == user.user_id
|
||||
)
|
||||
)
|
||||
plugin = result.scalar_one_or_none()
|
||||
|
||||
if not plugin:
|
||||
raise HTTPException(status_code=404, detail="插件不存在")
|
||||
|
||||
if not plugin.enabled:
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="插件未启用",
|
||||
error="请先启用插件",
|
||||
suggestions=["点击开关按钮启用插件"]
|
||||
)
|
||||
|
||||
# 使用测试服务
|
||||
try:
|
||||
test_result = await mcp_test_service.test_plugin_with_ai(plugin, user, db)
|
||||
|
||||
# 更新插件状态
|
||||
if test_result.success:
|
||||
plugin.status = "active"
|
||||
plugin.last_error = None
|
||||
else:
|
||||
plugin.status = "error"
|
||||
plugin.last_error = test_result.error
|
||||
|
||||
plugin.last_test_at = datetime.now()
|
||||
await db.commit()
|
||||
|
||||
return test_result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"测试插件失败: {plugin.plugin_name}, 错误: {e}")
|
||||
plugin.status = "error"
|
||||
plugin.last_error = str(e)
|
||||
plugin.last_test_at = datetime.now()
|
||||
await db.commit()
|
||||
raise HTTPException(status_code=500, detail=f"测试失败: {str(e)}")
|
||||
|
||||
|
||||
async def _ensure_plugin_registered(
|
||||
plugin: MCPPlugin,
|
||||
user_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
确保插件已注册到统一门面
|
||||
|
||||
Args:
|
||||
plugin: 插件对象
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
|
||||
Raises:
|
||||
HTTPException: 注册失败
|
||||
"""
|
||||
try:
|
||||
# 使用ensure_registered方法,它会检查是否已注册
|
||||
if plugin.plugin_type in ["http", "streamable_http", "sse"] and plugin.server_url:
|
||||
return await mcp_client.ensure_registered(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin.plugin_name,
|
||||
url=plugin.server_url,
|
||||
plugin_type=plugin.plugin_type,
|
||||
headers=plugin.headers
|
||||
)
|
||||
return False
|
||||
except ValueError as e:
|
||||
logger.info(f"插件 {plugin.plugin_name} 未注册,自动注册中...")
|
||||
success = await _register_plugin_to_facade(plugin, user_id)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"插件注册失败: {plugin.plugin_name}"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
@router.get("/{plugin_id}/status")
|
||||
async def get_plugin_status(
|
||||
plugin_id: str,
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取插件的实时状态(包括内存中的会话状态)"""
|
||||
result = await db.execute(
|
||||
select(MCPPlugin).where(
|
||||
MCPPlugin.id == plugin_id,
|
||||
MCPPlugin.user_id == user.user_id
|
||||
)
|
||||
)
|
||||
plugin = result.scalar_one_or_none()
|
||||
|
||||
if not plugin:
|
||||
raise HTTPException(status_code=404, detail="插件不存在")
|
||||
|
||||
session_stats = mcp_client.get_session_stats()
|
||||
session_key = f"{user.user_id}:{plugin.plugin_name}"
|
||||
session_info = next((s for s in session_stats.get("sessions", []) if s["key"] == session_key), None)
|
||||
|
||||
return {
|
||||
"plugin_id": plugin_id,
|
||||
"plugin_name": plugin.plugin_name,
|
||||
"db_status": plugin.status,
|
||||
"session_status": session_info["status"] if session_info else None,
|
||||
"is_registered": session_info is not None,
|
||||
"error_rate": session_info["error_rate"] if session_info else 0,
|
||||
"in_sync": (plugin.status == session_info["status"]) if session_info else (plugin.status == "inactive"),
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.get("/metrics")
|
||||
async def get_metrics(
|
||||
tool_name: Optional[str] = Query(None, description="工具名称(可选,获取特定工具的指标)"),
|
||||
user: User = Depends(require_login)
|
||||
):
|
||||
"""
|
||||
获取MCP工具调用指标
|
||||
|
||||
Query参数:
|
||||
- tool_name: 可选,指定工具名称获取特定工具的指标
|
||||
|
||||
Returns:
|
||||
工具调用指标字典,包含:
|
||||
- total_calls: 总调用次数
|
||||
- success_calls: 成功调用次数
|
||||
- failed_calls: 失败调用次数
|
||||
- success_rate: 成功率
|
||||
- avg_duration_ms: 平均耗时(毫秒)
|
||||
- last_call_time: 最后调用时间
|
||||
"""
|
||||
# 使用统一门面获取指标
|
||||
metrics = mcp_client.get_metrics(tool_name)
|
||||
|
||||
return {
|
||||
"metrics": metrics,
|
||||
"tool_name": tool_name,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.get("/cache/stats")
|
||||
async def get_cache_stats(
|
||||
user: User = Depends(require_login)
|
||||
):
|
||||
"""
|
||||
获取工具缓存统计信息
|
||||
|
||||
Returns:
|
||||
缓存统计信息,包含:
|
||||
- total_entries: 缓存条目总数
|
||||
- total_hits: 缓存总命中次数
|
||||
- cache_ttl_minutes: 缓存TTL(分钟)
|
||||
- entries: 各缓存条目详情
|
||||
"""
|
||||
# 使用统一门面获取缓存统计
|
||||
stats = mcp_client.get_cache_stats()
|
||||
|
||||
return {
|
||||
"cache_stats": stats,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.get("/sessions/stats")
|
||||
async def get_session_stats(
|
||||
user: User = Depends(require_login)
|
||||
):
|
||||
"""
|
||||
获取MCP会话统计信息
|
||||
|
||||
Returns:
|
||||
会话统计信息,包含:
|
||||
- total_sessions: 会话总数
|
||||
- sessions: 各会话详情
|
||||
"""
|
||||
# 使用统一门面获取会话统计
|
||||
stats = mcp_client.get_session_stats()
|
||||
|
||||
return {
|
||||
"session_stats": stats,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.post("/cache/clear")
|
||||
async def clear_cache(
|
||||
user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
plugin_name: Optional[str] = Query(None, description="插件名称(可选)"),
|
||||
user: User = Depends(require_login)
|
||||
):
|
||||
"""
|
||||
清理工具缓存
|
||||
|
||||
Query参数:
|
||||
- user_id: 可选,清理特定用户的缓存
|
||||
- plugin_name: 可选,清理特定插件的缓存
|
||||
|
||||
说明:
|
||||
- 不提供任何参数:清理所有缓存
|
||||
- 只提供user_id:清理该用户的所有缓存
|
||||
- 提供user_id和plugin_name:清理特定插件的缓存
|
||||
"""
|
||||
# 非管理员只能清理自己的缓存
|
||||
if user_id and user_id != user.user_id:
|
||||
raise HTTPException(status_code=403, detail="无权清理其他用户的缓存")
|
||||
|
||||
# 如果没有指定user_id,使用当前用户
|
||||
target_user_id = user_id or user.user_id
|
||||
|
||||
# 使用统一门面清理缓存
|
||||
mcp_client.clear_cache(target_user_id, plugin_name)
|
||||
|
||||
message = "已清理"
|
||||
if plugin_name:
|
||||
message += f"插件 {plugin_name} 的缓存"
|
||||
elif target_user_id:
|
||||
message += f"用户 {target_user_id} 的所有缓存"
|
||||
else:
|
||||
message += "所有缓存"
|
||||
|
||||
logger.info(f"用户 {user.user_id} {message}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": message,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{plugin_id}/tools")
|
||||
async def get_plugin_tools(
|
||||
plugin_id: str,
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取插件提供的工具列表
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(MCPPlugin).where(
|
||||
MCPPlugin.id == plugin_id,
|
||||
MCPPlugin.user_id == user.user_id
|
||||
)
|
||||
)
|
||||
plugin = result.scalar_one_or_none()
|
||||
|
||||
if not plugin:
|
||||
raise HTTPException(status_code=404, detail="插件不存在")
|
||||
|
||||
if not plugin.enabled:
|
||||
raise HTTPException(status_code=400, detail="插件未启用")
|
||||
|
||||
try:
|
||||
# 确保插件已注册
|
||||
await _ensure_plugin_registered(plugin, user.user_id)
|
||||
|
||||
# 使用统一门面获取工具列表
|
||||
tools = await mcp_client.get_tools(user.user_id, plugin.plugin_name)
|
||||
|
||||
# 更新数据库中的工具缓存
|
||||
plugin.tools = tools
|
||||
await db.commit()
|
||||
|
||||
return {
|
||||
"plugin_name": plugin.plugin_name,
|
||||
"tools": tools,
|
||||
"count": len(tools)
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具列表失败: {plugin.plugin_name}, 错误: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取工具列表失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/call")
|
||||
async def call_mcp_tool(
|
||||
data: MCPToolCall,
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
调用MCP工具
|
||||
"""
|
||||
# 获取插件
|
||||
result = await db.execute(
|
||||
select(MCPPlugin).where(
|
||||
MCPPlugin.id == data.plugin_id,
|
||||
MCPPlugin.user_id == user.user_id
|
||||
)
|
||||
)
|
||||
plugin = result.scalar_one_or_none()
|
||||
|
||||
if not plugin:
|
||||
raise HTTPException(status_code=404, detail="插件不存在")
|
||||
|
||||
if not plugin.enabled:
|
||||
raise HTTPException(status_code=400, detail="插件未启用")
|
||||
|
||||
try:
|
||||
# 确保插件已注册
|
||||
await _ensure_plugin_registered(plugin, user.user_id)
|
||||
|
||||
# 使用统一门面调用工具
|
||||
tool_result = await mcp_client.call_tool(
|
||||
user_id=user.user_id,
|
||||
plugin_name=plugin.plugin_name,
|
||||
tool_name=data.tool_name,
|
||||
arguments=data.arguments
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"plugin_name": plugin.plugin_name,
|
||||
"tool_name": data.tool_name,
|
||||
"result": tool_result
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"调用工具失败: {plugin.plugin_name}.{data.tool_name}, 错误: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"工具调用失败: {str(e)}")
|
||||
@@ -6,6 +6,7 @@ from typing import List, Optional
|
||||
from app.database import get_db
|
||||
from app.models.memory import StoryMemory, PlotAnalysis
|
||||
from app.models.chapter import Chapter
|
||||
from app.models.project import Project
|
||||
from app.services.memory_service import memory_service
|
||||
from app.services.plot_analyzer import get_plot_analyzer
|
||||
from app.services.ai_service import create_user_ai_service
|
||||
@@ -17,6 +18,26 @@ logger = get_logger(__name__)
|
||||
router = APIRouter(prefix="/api/memories", tags=["memories"])
|
||||
|
||||
|
||||
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
|
||||
"""验证用户是否有权访问指定项目"""
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
result = await db.execute(
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||
|
||||
return project
|
||||
|
||||
|
||||
@router.post("/projects/{project_id}/analyze-chapter/{chapter_id}")
|
||||
async def analyze_chapter(
|
||||
project_id: str,
|
||||
@@ -30,7 +51,10 @@ async def analyze_chapter(
|
||||
对指定章节进行剧情分析,提取钩子、伏笔、情节点等,并存入记忆系统
|
||||
"""
|
||||
try:
|
||||
user_id = request.state.user_id
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
|
||||
# 验证用户权限
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
# 获取章节内容
|
||||
result = await db.execute(
|
||||
@@ -192,7 +216,10 @@ async def get_project_memories(
|
||||
):
|
||||
"""获取项目的记忆列表"""
|
||||
try:
|
||||
user_id = request.state.user_id
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
|
||||
# 验证用户权限
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
# 构建查询
|
||||
query = select(StoryMemory).where(StoryMemory.project_id == project_id)
|
||||
@@ -222,10 +249,16 @@ async def get_project_memories(
|
||||
async def get_chapter_analysis(
|
||||
project_id: str,
|
||||
chapter_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取章节的剧情分析"""
|
||||
try:
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
|
||||
# 验证用户权限
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
result = await db.execute(
|
||||
select(PlotAnalysis).where(
|
||||
and_(
|
||||
@@ -258,11 +291,15 @@ async def search_memories(
|
||||
query: str,
|
||||
memory_types: Optional[List[str]] = None,
|
||||
limit: int = 10,
|
||||
min_importance: float = 0.0
|
||||
min_importance: float = 0.0,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""语义搜索项目记忆"""
|
||||
try:
|
||||
user_id = request.state.user_id
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
|
||||
# 验证用户权限
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
memories = await memory_service.search_memories(
|
||||
user_id=user_id,
|
||||
@@ -294,7 +331,10 @@ async def get_unresolved_foreshadows(
|
||||
):
|
||||
"""获取未完结的伏笔"""
|
||||
try:
|
||||
user_id = request.state.user_id
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
|
||||
# 验证用户权限
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
# 从向量库搜索
|
||||
foreshadows = await memory_service.find_unresolved_foreshadows(
|
||||
@@ -317,11 +357,15 @@ async def get_unresolved_foreshadows(
|
||||
@router.get("/projects/{project_id}/stats")
|
||||
async def get_memory_stats(
|
||||
project_id: str,
|
||||
request: Request
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取记忆统计信息"""
|
||||
try:
|
||||
user_id = request.state.user_id
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
|
||||
# 验证用户权限
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
stats = await memory_service.get_memory_stats(
|
||||
user_id=user_id,
|
||||
@@ -347,7 +391,10 @@ async def delete_chapter_memories(
|
||||
):
|
||||
"""删除章节的所有记忆"""
|
||||
try:
|
||||
user_id = request.state.user_id
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
|
||||
# 验证用户权限
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
# 从数据库删除
|
||||
result = await db.execute(
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
"""组织管理API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_
|
||||
from typing import List
|
||||
from typing import List, Optional, AsyncGenerator
|
||||
from pydantic import BaseModel, Field
|
||||
import json
|
||||
|
||||
from app.database import get_db
|
||||
from app.utils.sse_response import SSEResponse, create_sse_response
|
||||
from app.models.relationship import Organization, OrganizationMember
|
||||
from app.models.character import Character
|
||||
from app.models.project import Project
|
||||
from app.models.generation_history import GenerationHistory
|
||||
from app.schemas.relationship import (
|
||||
OrganizationCreate,
|
||||
OrganizationUpdate,
|
||||
@@ -17,17 +22,56 @@ from app.schemas.relationship import (
|
||||
OrganizationMemberResponse,
|
||||
OrganizationMemberDetailResponse
|
||||
)
|
||||
from app.schemas.character import CharacterResponse
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.prompt_service import prompt_service, PromptService
|
||||
from app.logger import get_logger
|
||||
from app.api.settings import get_user_ai_service
|
||||
|
||||
router = APIRouter(prefix="/organizations", tags=["组织管理"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
|
||||
"""验证用户是否有权访问指定项目"""
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
result = await db.execute(
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||
|
||||
return project
|
||||
|
||||
|
||||
class OrganizationGenerateRequest(BaseModel):
|
||||
"""AI生成组织的请求模型"""
|
||||
project_id: str = Field(..., description="项目ID")
|
||||
name: Optional[str] = Field(None, description="组织名称")
|
||||
organization_type: Optional[str] = Field(None, description="组织类型")
|
||||
background: Optional[str] = Field(None, description="组织背景")
|
||||
requirements: Optional[str] = Field(None, description="特殊要求")
|
||||
enable_mcp: bool = Field(True, description="是否启用MCP工具增强(搜索组织架构参考)")
|
||||
|
||||
|
||||
@router.get("/project/{project_id}", response_model=List[OrganizationDetailResponse], summary="获取项目的所有组织")
|
||||
async def get_project_organizations(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
"""
|
||||
获取项目中的所有组织及其详情
|
||||
|
||||
@@ -67,6 +111,7 @@ async def get_project_organizations(
|
||||
@router.get("/{org_id}", response_model=OrganizationResponse, summary="获取组织详情")
|
||||
async def get_organization(
|
||||
org_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取组织的详细信息"""
|
||||
@@ -78,12 +123,17 @@ async def get_organization(
|
||||
if not org:
|
||||
raise HTTPException(status_code=404, detail="组织不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(org.project_id, user_id, db)
|
||||
|
||||
return org
|
||||
|
||||
|
||||
@router.post("/", response_model=OrganizationResponse, summary="创建组织")
|
||||
@router.post("", response_model=OrganizationResponse, summary="创建组织")
|
||||
async def create_organization(
|
||||
organization: OrganizationCreate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -92,6 +142,10 @@ async def create_organization(
|
||||
- 需要关联到一个已存在的角色记录(is_organization=True)
|
||||
- 可以设置父组织、势力等级等属性
|
||||
"""
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(organization.project_id, user_id, db)
|
||||
|
||||
# 验证角色是否存在且是组织
|
||||
char_result = await db.execute(
|
||||
select(Character).where(Character.id == organization.character_id)
|
||||
@@ -124,6 +178,7 @@ async def create_organization(
|
||||
async def update_organization(
|
||||
org_id: str,
|
||||
organization: OrganizationUpdate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""更新组织的属性"""
|
||||
@@ -135,7 +190,11 @@ async def update_organization(
|
||||
if not db_org:
|
||||
raise HTTPException(status_code=404, detail="组织不存在")
|
||||
|
||||
# 更新字段
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(db_org.project_id, user_id, db)
|
||||
|
||||
# 更新 Organization 表字段
|
||||
update_data = organization.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(db_org, field, value)
|
||||
@@ -150,6 +209,7 @@ async def update_organization(
|
||||
@router.delete("/{org_id}", summary="删除组织")
|
||||
async def delete_organization(
|
||||
org_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""删除组织(会级联删除所有成员关系)"""
|
||||
@@ -161,6 +221,10 @@ async def delete_organization(
|
||||
if not db_org:
|
||||
raise HTTPException(status_code=404, detail="组织不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(db_org.project_id, user_id, db)
|
||||
|
||||
await db.delete(db_org)
|
||||
await db.commit()
|
||||
|
||||
@@ -173,6 +237,7 @@ async def delete_organization(
|
||||
@router.get("/{org_id}/members", response_model=List[OrganizationMemberDetailResponse], summary="获取组织成员")
|
||||
async def get_organization_members(
|
||||
org_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -184,9 +249,14 @@ async def get_organization_members(
|
||||
org_result = await db.execute(
|
||||
select(Organization).where(Organization.id == org_id)
|
||||
)
|
||||
if not org_result.scalar_one_or_none():
|
||||
org = org_result.scalar_one_or_none()
|
||||
if not org:
|
||||
raise HTTPException(status_code=404, detail="组织不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(org.project_id, user_id, db)
|
||||
|
||||
# 获取成员列表
|
||||
result = await db.execute(
|
||||
select(OrganizationMember)
|
||||
@@ -226,6 +296,7 @@ async def get_organization_members(
|
||||
async def add_organization_member(
|
||||
org_id: str,
|
||||
member: OrganizationMemberCreate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -242,6 +313,10 @@ async def add_organization_member(
|
||||
if not org:
|
||||
raise HTTPException(status_code=404, detail="组织不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(org.project_id, user_id, db)
|
||||
|
||||
# 验证角色存在
|
||||
char_result = await db.execute(
|
||||
select(Character).where(Character.id == member.character_id)
|
||||
@@ -286,6 +361,7 @@ async def add_organization_member(
|
||||
async def update_organization_member(
|
||||
member_id: str,
|
||||
member: OrganizationMemberUpdate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""更新组织成员的职位、忠诚度等信息"""
|
||||
@@ -297,6 +373,14 @@ async def update_organization_member(
|
||||
if not db_member:
|
||||
raise HTTPException(status_code=404, detail="成员记录不存在")
|
||||
|
||||
# 通过成员所属的组织验证用户权限
|
||||
org_result = await db.execute(
|
||||
select(Organization).where(Organization.id == db_member.organization_id)
|
||||
)
|
||||
org = org_result.scalar_one()
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(org.project_id, user_id, db)
|
||||
|
||||
# 更新字段
|
||||
update_data = member.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
@@ -312,6 +396,7 @@ async def update_organization_member(
|
||||
@router.delete("/members/{member_id}", summary="移除组织成员")
|
||||
async def remove_organization_member(
|
||||
member_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -332,10 +417,225 @@ async def remove_organization_member(
|
||||
select(Organization).where(Organization.id == db_member.organization_id)
|
||||
)
|
||||
org = org_result.scalar_one()
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(org.project_id, user_id, db)
|
||||
org.member_count = max(0, org.member_count - 1)
|
||||
|
||||
await db.delete(db_member)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"移除成员成功:{member_id}")
|
||||
return {"message": "成员移除成功", "id": member_id}
|
||||
return {"message": "成员移除成功", "id": member_id}
|
||||
|
||||
@router.post("/generate-stream", summary="AI生成组织(流式)")
|
||||
async def generate_organization_stream(
|
||||
gen_request: OrganizationGenerateRequest,
|
||||
http_request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
"""
|
||||
使用AI生成组织设定(支持SSE流式进度显示)
|
||||
|
||||
通过Server-Sent Events返回实时进度信息
|
||||
"""
|
||||
async def generate() -> AsyncGenerator[str, None]:
|
||||
try:
|
||||
# 验证用户权限和项目是否存在
|
||||
user_id = getattr(http_request.state, 'user_id', None)
|
||||
project = await verify_project_access(gen_request.project_id, user_id, db)
|
||||
|
||||
yield await SSEResponse.send_progress("开始生成组织...", 0)
|
||||
|
||||
# 获取已存在的角色和组织列表
|
||||
yield await SSEResponse.send_progress("获取项目上下文...", 10)
|
||||
|
||||
existing_chars_result = await db.execute(
|
||||
select(Character)
|
||||
.where(Character.project_id == gen_request.project_id)
|
||||
.order_by(Character.created_at.desc())
|
||||
)
|
||||
existing_characters = existing_chars_result.scalars().all()
|
||||
|
||||
# 构建现有角色和组织信息摘要
|
||||
existing_info = ""
|
||||
character_list = []
|
||||
organization_list = []
|
||||
|
||||
if existing_characters:
|
||||
for c in existing_characters[:10]:
|
||||
if c.is_organization:
|
||||
organization_list.append(f"- {c.name} [{c.organization_type or '组织'}]")
|
||||
else:
|
||||
character_list.append(f"- {c.name}({c.role_type or '未知'})")
|
||||
|
||||
if character_list:
|
||||
existing_info += "\n已有角色:\n" + "\n".join(character_list)
|
||||
if organization_list:
|
||||
existing_info += "\n\n已有组织:\n" + "\n".join(organization_list)
|
||||
|
||||
# 构建项目上下文
|
||||
project_context = f"""
|
||||
项目信息:
|
||||
- 书名:{project.title}
|
||||
- 主题:{project.theme or '未设定'}
|
||||
- 类型:{project.genre or '未设定'}
|
||||
- 时间背景:{project.world_time_period or '未设定'}
|
||||
- 地理位置:{project.world_location or '未设定'}
|
||||
- 氛围基调:{project.world_atmosphere or '未设定'}
|
||||
- 世界规则:{project.world_rules or '未设定'}
|
||||
{existing_info}
|
||||
"""
|
||||
|
||||
user_input = f"""
|
||||
用户要求:
|
||||
- 组织名称:{gen_request.name or '请AI生成'}
|
||||
- 组织类型:{gen_request.organization_type or '请AI根据世界观决定'}
|
||||
- 背景设定:{gen_request.background or '无特殊要求'}
|
||||
- 其他要求:{gen_request.requirements or '无'}
|
||||
"""
|
||||
|
||||
yield await SSEResponse.send_progress("构建AI提示词...", 5)
|
||||
|
||||
# 获取自定义提示词模板
|
||||
template = await PromptService.get_template("SINGLE_ORGANIZATION_GENERATION", user_id, db)
|
||||
# 格式化提示词
|
||||
prompt = PromptService.format_prompt(
|
||||
template,
|
||||
project_context=project_context,
|
||||
user_input=user_input
|
||||
)
|
||||
|
||||
yield await SSEResponse.send_progress("调用AI服务生成组织...", 10)
|
||||
logger.info(f"🎯 开始为项目 {gen_request.project_id} 生成组织(SSE流式)")
|
||||
|
||||
try:
|
||||
# 使用流式生成替代非流式
|
||||
ai_content = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
chunk_count += 1
|
||||
ai_content += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新字数(5-95%,AI生成占90%)
|
||||
if chunk_count % 5 == 0:
|
||||
progress = min(10 + (chunk_count // 5), 95)
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成组织中... ({len(ai_content)}字符)",
|
||||
progress
|
||||
)
|
||||
|
||||
# 心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
except Exception as ai_error:
|
||||
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
|
||||
yield await SSEResponse.send_error(f"AI服务调用失败:{str(ai_error)}")
|
||||
return
|
||||
|
||||
if not ai_content or not ai_content.strip():
|
||||
yield await SSEResponse.send_error("AI服务返回空响应")
|
||||
return
|
||||
|
||||
yield await SSEResponse.send_progress("解析AI响应...", 90)
|
||||
|
||||
# ✅ 使用统一的 JSON 清洗方法
|
||||
try:
|
||||
cleaned_response = user_ai_service._clean_json_response(ai_content)
|
||||
organization_data = json.loads(cleaned_response)
|
||||
logger.info(f"✅ 组织JSON解析成功")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"❌ 组织JSON解析失败: {e}")
|
||||
logger.error(f" 原始响应预览: {ai_content[:200]}")
|
||||
yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON:{str(e)}")
|
||||
return
|
||||
|
||||
yield await SSEResponse.send_progress("创建组织记录...", 95)
|
||||
|
||||
# 创建角色记录(组织也是角色的一种)
|
||||
character = Character(
|
||||
project_id=gen_request.project_id,
|
||||
name=organization_data.get("name", gen_request.name or "未命名组织"),
|
||||
is_organization=True,
|
||||
role_type="supporting",
|
||||
personality=organization_data.get("personality", ""),
|
||||
background=organization_data.get("background", ""),
|
||||
appearance=organization_data.get("appearance", ""),
|
||||
organization_type=organization_data.get("organization_type"),
|
||||
organization_purpose=organization_data.get("organization_purpose"),
|
||||
organization_members=json.dumps(
|
||||
organization_data.get("organization_members", []),
|
||||
ensure_ascii=False
|
||||
),
|
||||
traits=json.dumps(
|
||||
organization_data.get("traits", []),
|
||||
ensure_ascii=False
|
||||
)
|
||||
)
|
||||
db.add(character)
|
||||
await db.flush()
|
||||
|
||||
logger.info(f"✅ 组织角色创建成功:{character.name} (ID: {character.id})")
|
||||
|
||||
yield await SSEResponse.send_progress("创建组织详情...", 98)
|
||||
|
||||
# 自动创建Organization详情记录
|
||||
organization = Organization(
|
||||
character_id=character.id,
|
||||
project_id=gen_request.project_id,
|
||||
member_count=0,
|
||||
power_level=organization_data.get("power_level", 50),
|
||||
location=organization_data.get("location"),
|
||||
motto=organization_data.get("motto"),
|
||||
color=organization_data.get("color")
|
||||
)
|
||||
db.add(organization)
|
||||
await db.flush()
|
||||
|
||||
logger.info(f"✅ 组织详情创建成功:{character.name} (Org ID: {organization.id})")
|
||||
|
||||
yield await SSEResponse.send_progress("保存生成历史...", 99)
|
||||
|
||||
# 记录生成历史
|
||||
history = GenerationHistory(
|
||||
project_id=gen_request.project_id,
|
||||
prompt=prompt,
|
||||
generated_content=ai_content,
|
||||
model=user_ai_service.default_model
|
||||
)
|
||||
db.add(history)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(character)
|
||||
|
||||
logger.info(f"🎉 成功生成组织: {character.name}")
|
||||
|
||||
yield await SSEResponse.send_progress("组织生成完成!", 100, "success")
|
||||
|
||||
# 发送结果数据
|
||||
yield await SSEResponse.send_result({
|
||||
"character": {
|
||||
"id": character.id,
|
||||
"name": character.name,
|
||||
"organization_type": character.organization_type,
|
||||
"is_organization": character.is_organization
|
||||
}
|
||||
})
|
||||
|
||||
yield await SSEResponse.send_done()
|
||||
|
||||
except HTTPException as he:
|
||||
logger.error(f"HTTP异常: {he.detail}")
|
||||
yield await SSEResponse.send_error(he.detail, he.status_code)
|
||||
except Exception as e:
|
||||
logger.error(f"生成组织失败: {str(e)}")
|
||||
yield await SSEResponse.send_error(f"生成组织失败: {str(e)}")
|
||||
|
||||
return create_sse_response(generate())
|
||||
|
||||
+2543
-385
File diff suppressed because it is too large
Load Diff
@@ -1,12 +1,12 @@
|
||||
"""AI去味API - 核心特色功能"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.generation_history import GenerationHistory
|
||||
from app.schemas.polish import PolishRequest, PolishResponse
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.prompt_service import prompt_service
|
||||
from app.services.prompt_service import prompt_service, PromptService
|
||||
from app.logger import get_logger
|
||||
from app.api.settings import get_user_ai_service
|
||||
|
||||
@@ -17,6 +17,7 @@ logger = get_logger(__name__)
|
||||
@router.post("", response_model=PolishResponse, summary="AI去味")
|
||||
async def polish_text(
|
||||
request: PolishRequest,
|
||||
http_request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
@@ -32,15 +33,21 @@ async def polish_text(
|
||||
这是本项目的核心特色功能!
|
||||
"""
|
||||
try:
|
||||
# 构建AI去味提示词
|
||||
prompt = prompt_service.get_denoising_prompt(
|
||||
# 获取用户ID
|
||||
user_id = getattr(http_request.state, 'user_id', None)
|
||||
|
||||
# 获取自定义提示词模板
|
||||
template = await PromptService.get_template("AI_DENOISING", user_id, db)
|
||||
# 格式化提示词
|
||||
prompt = PromptService.format_prompt(
|
||||
template,
|
||||
original_text=request.original_text
|
||||
)
|
||||
|
||||
logger.info(f"开始AI去味处理,原文长度: {len(request.original_text)}")
|
||||
|
||||
# 调用AI进行去味处理
|
||||
polished_text = await ai_service.generate_text(
|
||||
polished_text = await user_ai_service.generate_text(
|
||||
prompt=prompt,
|
||||
provider=request.provider,
|
||||
model=request.model,
|
||||
@@ -85,6 +92,7 @@ async def polish_batch(
|
||||
project_id: int = None,
|
||||
provider: str = None,
|
||||
model: str = None,
|
||||
http_request: Request = None,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
@@ -94,12 +102,18 @@ async def polish_batch(
|
||||
适用于一次性处理多个章节或段落
|
||||
"""
|
||||
try:
|
||||
# 获取用户ID
|
||||
user_id = getattr(http_request.state, 'user_id', None) if http_request else None
|
||||
|
||||
results = []
|
||||
|
||||
for idx, text in enumerate(texts):
|
||||
logger.info(f"处理第 {idx+1}/{len(texts)} 个文本")
|
||||
|
||||
prompt = prompt_service.get_denoising_prompt(original_text=text)
|
||||
# 获取自定义提示词模板
|
||||
template = await PromptService.get_template("AI_DENOISING", user_id, db)
|
||||
# 格式化提示词
|
||||
prompt = PromptService.format_prompt(template, original_text=text)
|
||||
|
||||
polished_text = await user_ai_service.generate_text(
|
||||
prompt=prompt,
|
||||
|
||||
+188
-43
@@ -1,5 +1,5 @@
|
||||
"""项目管理API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Request
|
||||
from fastapi.responses import Response
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, delete
|
||||
@@ -13,6 +13,7 @@ from app.models.outline import Outline
|
||||
from app.models.chapter import Chapter
|
||||
from app.models.generation_history import GenerationHistory
|
||||
from app.models.relationship import CharacterRelationship, Organization, OrganizationMember
|
||||
from app.models.memory import StoryMemory, PlotAnalysis
|
||||
from app.schemas.project import (
|
||||
ProjectCreate,
|
||||
ProjectUpdate,
|
||||
@@ -25,6 +26,7 @@ from app.schemas.import_export import (
|
||||
ImportResult
|
||||
)
|
||||
from app.services.import_export_service import ImportExportService
|
||||
from app.services.memory_service import memory_service
|
||||
from app.logger import get_logger
|
||||
from app.utils.data_consistency import (
|
||||
run_full_data_consistency_check,
|
||||
@@ -39,17 +41,31 @@ router = APIRouter(prefix="/projects", tags=["项目管理"])
|
||||
@router.post("", response_model=ProjectResponse, summary="创建项目")
|
||||
async def create_project(
|
||||
project: ProjectCreate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
request: Request = None
|
||||
):
|
||||
try:
|
||||
logger.info(f"创建新项目: {project.title}")
|
||||
db_project = Project(**project.model_dump())
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
logger.warning("未登录用户尝试创建项目")
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
logger.info(f"创建新项目: {project.title}, user_id={user_id}")
|
||||
|
||||
# 创建项目时自动设置user_id
|
||||
project_data = project.model_dump()
|
||||
project_data['user_id'] = user_id
|
||||
db_project = Project(**project_data)
|
||||
|
||||
db.add(db_project)
|
||||
await db.commit()
|
||||
await db.refresh(db_project)
|
||||
logger.info(f"项目创建成功: {db_project.id}")
|
||||
logger.info(f"项目创建成功: project_id={db_project.id}, user_id={user_id}")
|
||||
|
||||
return db_project
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"创建项目失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
@@ -59,24 +75,38 @@ async def create_project(
|
||||
async def get_projects(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
request: Request = None
|
||||
):
|
||||
"""获取所有项目列表"""
|
||||
"""获取当前用户的项目列表"""
|
||||
try:
|
||||
logger.debug(f"获取项目列表: skip={skip}, limit={limit}")
|
||||
count_result = await db.execute(select(func.count(Project.id)))
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
logger.warning("未登录用户尝试获取项目列表")
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
logger.debug(f"获取项目列表: user_id={user_id}, skip={skip}, limit={limit}")
|
||||
|
||||
# 只查询当前用户的项目
|
||||
count_result = await db.execute(
|
||||
select(func.count(Project.id)).where(Project.user_id == user_id)
|
||||
)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
result = await db.execute(
|
||||
select(Project)
|
||||
.where(Project.user_id == user_id)
|
||||
.order_by(Project.updated_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
projects = result.scalars().all()
|
||||
logger.info(f"获取项目列表成功: 共{total}个项目")
|
||||
logger.info(f"获取项目列表成功: user_id={user_id}, 共{total}个项目")
|
||||
|
||||
return ProjectListResponse(total=total, items=projects)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取项目列表失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
@@ -85,17 +115,29 @@ async def get_projects(
|
||||
@router.get("/{project_id}", response_model=ProjectResponse, summary="获取项目详情")
|
||||
async def get_project(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
request: Request = None
|
||||
):
|
||||
try:
|
||||
logger.debug(f"获取项目详情: {project_id}")
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
logger.warning("未登录用户尝试获取项目详情")
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
logger.debug(f"获取项目详情: project_id={project_id}, user_id={user_id}")
|
||||
|
||||
# 只查询当前用户的项目
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
logger.info(f"获取项目详情成功: {project.title}")
|
||||
@@ -111,17 +153,29 @@ async def get_project(
|
||||
async def update_project(
|
||||
project_id: str,
|
||||
project_update: ProjectUpdate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
request: Request = None
|
||||
):
|
||||
try:
|
||||
logger.info(f"更新项目: {project_id}")
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
logger.warning("未登录用户尝试更新项目")
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
logger.info(f"更新项目: project_id={project_id}, user_id={user_id}")
|
||||
|
||||
# 只查询当前用户的项目
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
update_data = project_update.model_dump(exclude_unset=True)
|
||||
@@ -143,21 +197,43 @@ async def update_project(
|
||||
@router.delete("/{project_id}", summary="删除项目")
|
||||
async def delete_project(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
try:
|
||||
logger.info(f"删除项目: {project_id}")
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
logger.warning("未登录用户尝试删除项目")
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
logger.info(f"删除项目: project_id={project_id}, user_id={user_id}")
|
||||
|
||||
# 只查询当前用户的项目
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
project_title = project.title
|
||||
|
||||
# 删除向量数据库中的记忆(user_id已在上面获取)
|
||||
if user_id:
|
||||
try:
|
||||
await memory_service.delete_project_memories(user_id, project_id)
|
||||
logger.info(f"✅ 向量数据库清理成功")
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 向量数据库清理失败(继续删除其他数据): {str(e)}")
|
||||
else:
|
||||
logger.warning(f"⚠️ 未找到用户ID,跳过向量数据库清理")
|
||||
|
||||
relationships_result = await db.execute(
|
||||
delete(CharacterRelationship).where(CharacterRelationship.project_id == project_id)
|
||||
)
|
||||
@@ -200,11 +276,14 @@ async def delete_project(
|
||||
)
|
||||
logger.debug(f"删除角色数: {characters_result.rowcount}")
|
||||
|
||||
# 注意:StoryMemory和PlotAnalysis会通过数据库级联删除自动清理
|
||||
# 但向量数据库已在上面手动清理
|
||||
|
||||
await db.delete(project)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"项目删除成功: {project_title}")
|
||||
return {"message": "项目及所有关联数据删除成功"}
|
||||
return {"message": "项目及所有关联数据(包括向量数据库)删除成功"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -215,22 +294,33 @@ async def delete_project(
|
||||
@router.get("/{project_id}/export", summary="导出项目章节为TXT")
|
||||
async def export_project_chapters(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
request: Request = None
|
||||
):
|
||||
"""
|
||||
导出项目的所有章节内容为TXT文本文件
|
||||
按章节顺序组织,包含项目基本信息
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始导出项目: {project_id}")
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
logger.warning("未登录用户尝试导出项目")
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
logger.info(f"开始导出项目: project_id={project_id}, user_id={user_id}")
|
||||
|
||||
# 只查询当前用户的项目
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
chapters_result = await db.execute(
|
||||
@@ -264,6 +354,7 @@ async def export_project_chapters(
|
||||
txt_content.append("\n" + "=" * 80 + "\n\n")
|
||||
|
||||
for chapter in chapters:
|
||||
# 只显示主章节号,不显示子索引
|
||||
txt_content.append(f"第 {chapter.chapter_number} 章 {chapter.title}")
|
||||
txt_content.append("-" * 80)
|
||||
txt_content.append("") # 空行
|
||||
@@ -276,7 +367,11 @@ async def export_project_chapters(
|
||||
txt_content.append("\n\n" + "=" * 80 + "\n\n")
|
||||
|
||||
txt_content.append(f"--- 全文完 ---")
|
||||
txt_content.append(f"\n导出时间: {func.now()}")
|
||||
|
||||
# 获取当前时间
|
||||
from datetime import datetime
|
||||
export_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
txt_content.append(f"\n导出时间: {export_time}")
|
||||
|
||||
final_content = "\n".join(txt_content)
|
||||
|
||||
@@ -307,6 +402,7 @@ async def export_project_chapters(
|
||||
@router.post("/{project_id}/check-consistency", summary="检查数据一致性")
|
||||
async def check_project_consistency(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
auto_fix: bool = True,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
@@ -324,15 +420,25 @@ async def check_project_consistency(
|
||||
- organization_members: 验证组织成员数据完整性
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始数据一致性检查: {project_id}, auto_fix={auto_fix}")
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
logger.warning("未登录用户尝试检查数据一致性")
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
logger.info(f"开始数据一致性检查: project_id={project_id}, user_id={user_id}, auto_fix={auto_fix}")
|
||||
|
||||
# 只查询当前用户的项目
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
report = await run_full_data_consistency_check(project_id, db, auto_fix)
|
||||
@@ -350,6 +456,7 @@ async def check_project_consistency(
|
||||
@router.post("/{project_id}/fix-organizations", summary="修复组织记录")
|
||||
async def fix_project_organizations(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -358,15 +465,25 @@ async def fix_project_organizations(
|
||||
为所有is_organization=True但没有Organization记录的Character创建记录
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始修复组织记录: {project_id}")
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
logger.warning("未登录用户尝试修复组织记录")
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
logger.info(f"开始修复组织记录: project_id={project_id}, user_id={user_id}")
|
||||
|
||||
# 只查询当前用户的项目
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
fixed_count, total_count = await fix_missing_organization_records(project_id, db)
|
||||
@@ -388,6 +505,7 @@ async def fix_project_organizations(
|
||||
@router.post("/{project_id}/fix-member-counts", summary="修复成员计数")
|
||||
async def fix_project_member_counts(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -396,15 +514,25 @@ async def fix_project_member_counts(
|
||||
从实际成员记录重新计算每个组织的member_count
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始修复成员计数: {project_id}")
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
logger.warning("未登录用户尝试修复成员计数")
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
logger.info(f"开始修复成员计数: project_id={project_id}, user_id={user_id}")
|
||||
|
||||
# 只查询当前用户的项目
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
fixed_count, total_count = await fix_organization_member_counts(project_id, db)
|
||||
@@ -426,6 +554,7 @@ async def fix_project_member_counts(
|
||||
@router.post("/{project_id}/export-data", summary="导出项目数据为JSON")
|
||||
async def export_project_data(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
options: ExportOptions,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
@@ -440,16 +569,25 @@ async def export_project_data(
|
||||
JSON文件下载
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始导出项目数据: {project_id}")
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
logger.warning("未登录用户尝试导出项目数据")
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
# 检查项目是否存在
|
||||
logger.info(f"开始导出项目数据: project_id={project_id}, user_id={user_id}")
|
||||
|
||||
# 只查询当前用户的项目
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
# 导出数据
|
||||
@@ -538,6 +676,7 @@ async def validate_import_file(
|
||||
@router.post("/import", response_model=ImportResult, summary="导入项目")
|
||||
async def import_project(
|
||||
file: UploadFile = File(...),
|
||||
request: Request = None,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -550,7 +689,13 @@ async def import_project(
|
||||
导入结果
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始导入项目: {file.filename}")
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
logger.warning("未登录用户尝试导入项目")
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
logger.info(f"开始导入项目: {file.filename}, user_id={user_id}")
|
||||
|
||||
# 检查文件类型
|
||||
if not file.filename.endswith('.json'):
|
||||
@@ -570,8 +715,8 @@ async def import_project(
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"无效的JSON格式: {str(e)}")
|
||||
|
||||
# 导入数据
|
||||
import_result = await ImportExportService.import_project(data, db)
|
||||
# 导入数据(传入user_id)
|
||||
import_result = await ImportExportService.import_project(data, db, user_id)
|
||||
|
||||
if import_result.success:
|
||||
logger.info(f"项目导入成功: {import_result.project_id}")
|
||||
|
||||
@@ -0,0 +1,630 @@
|
||||
"""提示词模板管理 API"""
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, delete
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
import json
|
||||
import hashlib
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.prompt_template import PromptTemplate
|
||||
from app.schemas.prompt_template import (
|
||||
PromptTemplateCreate,
|
||||
PromptTemplateUpdate,
|
||||
PromptTemplateResponse,
|
||||
PromptTemplateListResponse,
|
||||
PromptTemplateCategoryResponse,
|
||||
PromptTemplateExport,
|
||||
PromptTemplateExportItem,
|
||||
PromptTemplateImportResult,
|
||||
PromptTemplatePreviewRequest
|
||||
)
|
||||
from app.services.prompt_service import PromptService
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
def calculate_content_hash(content: str) -> str:
|
||||
"""计算模板内容的SHA256哈希值"""
|
||||
return hashlib.sha256(content.strip().encode('utf-8')).hexdigest()[:16]
|
||||
|
||||
router = APIRouter(prefix="/prompt-templates", tags=["提示词模板管理"])
|
||||
|
||||
|
||||
@router.get("", response_model=PromptTemplateListResponse)
|
||||
async def get_all_templates(
|
||||
request: Request,
|
||||
category: Optional[str] = Query(None, description="按分类筛选"),
|
||||
is_active: Optional[bool] = Query(None, description="按启用状态筛选"),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取用户所有提示词模板
|
||||
"""
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
query = select(PromptTemplate).where(PromptTemplate.user_id == user_id)
|
||||
|
||||
if category:
|
||||
query = query.where(PromptTemplate.category == category)
|
||||
if is_active is not None:
|
||||
query = query.where(PromptTemplate.is_active == is_active)
|
||||
|
||||
query = query.order_by(PromptTemplate.category, PromptTemplate.template_key)
|
||||
|
||||
result = await db.execute(query)
|
||||
templates = result.scalars().all()
|
||||
|
||||
# 获取所有分类
|
||||
categories_result = await db.execute(
|
||||
select(PromptTemplate.category)
|
||||
.where(PromptTemplate.user_id == user_id)
|
||||
.distinct()
|
||||
)
|
||||
categories = [c for c in categories_result.scalars().all() if c]
|
||||
|
||||
return PromptTemplateListResponse(
|
||||
templates=templates,
|
||||
total=len(templates),
|
||||
categories=sorted(categories)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/categories", response_model=List[PromptTemplateCategoryResponse])
|
||||
async def get_templates_by_category(
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
按分类获取提示词模板(合并用户自定义和系统默认)
|
||||
"""
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
# 1. 查询用户自定义模板
|
||||
result = await db.execute(
|
||||
select(PromptTemplate)
|
||||
.where(PromptTemplate.user_id == user_id)
|
||||
.order_by(PromptTemplate.category, PromptTemplate.template_key)
|
||||
)
|
||||
user_templates = result.scalars().all()
|
||||
|
||||
# 2. 获取所有系统默认模板
|
||||
system_templates = PromptService.get_all_system_templates()
|
||||
|
||||
# 3. 构建用户自定义模板的键集合
|
||||
user_template_keys = {t.template_key for t in user_templates}
|
||||
|
||||
# 4. 合并模板:用户自定义的 + 未自定义的系统默认
|
||||
all_templates = []
|
||||
current_time = datetime.now()
|
||||
|
||||
# 添加用户自定义的模板
|
||||
for user_template in user_templates:
|
||||
user_template.is_system_default = False # 标记为已自定义
|
||||
all_templates.append(user_template)
|
||||
|
||||
# 添加未自定义的系统默认模板
|
||||
for sys_template in system_templates:
|
||||
if sys_template['template_key'] not in user_template_keys:
|
||||
# 这个系统模板用户还没有自定义,创建临时对象
|
||||
template_obj = PromptTemplate(
|
||||
id=sys_template['template_key'], # 使用template_key作为临时ID
|
||||
user_id=user_id,
|
||||
template_key=sys_template['template_key'],
|
||||
template_name=sys_template['template_name'],
|
||||
template_content=sys_template['content'],
|
||||
description=sys_template['description'],
|
||||
category=sys_template['category'],
|
||||
parameters=json.dumps(sys_template['parameters']),
|
||||
is_active=True,
|
||||
is_system_default=True,
|
||||
created_at=current_time,
|
||||
updated_at=current_time
|
||||
)
|
||||
all_templates.append(template_obj)
|
||||
|
||||
# 5. 按分类分组
|
||||
category_dict = {}
|
||||
for template in all_templates:
|
||||
cat = template.category or "未分类"
|
||||
if cat not in category_dict:
|
||||
category_dict[cat] = []
|
||||
category_dict[cat].append(template)
|
||||
|
||||
# 6. 构建响应
|
||||
response = []
|
||||
for category, temps in sorted(category_dict.items()):
|
||||
# 按template_key排序,确保顺序一致
|
||||
temps.sort(key=lambda t: t.template_key)
|
||||
response.append(PromptTemplateCategoryResponse(
|
||||
category=category,
|
||||
count=len(temps),
|
||||
templates=temps
|
||||
))
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@router.get("/system-defaults")
|
||||
async def get_system_defaults(
|
||||
request: Request
|
||||
):
|
||||
"""
|
||||
获取所有系统默认提示词模板
|
||||
"""
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
# 从PromptService获取所有系统默认模板
|
||||
system_templates = PromptService.get_all_system_templates()
|
||||
|
||||
return {
|
||||
"templates": system_templates,
|
||||
"total": len(system_templates)
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{template_key}", response_model=PromptTemplateResponse)
|
||||
async def get_template(
|
||||
template_key: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取指定的提示词模板
|
||||
"""
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
result = await db.execute(
|
||||
select(PromptTemplate).where(
|
||||
PromptTemplate.user_id == user_id,
|
||||
PromptTemplate.template_key == template_key
|
||||
)
|
||||
)
|
||||
template = result.scalar_one_or_none()
|
||||
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail=f"模板 {template_key} 不存在")
|
||||
|
||||
return template
|
||||
|
||||
|
||||
@router.post("", response_model=PromptTemplateResponse)
|
||||
async def create_or_update_template(
|
||||
data: PromptTemplateCreate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
创建或更新提示词模板(Upsert)
|
||||
"""
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
# 查找现有模板
|
||||
result = await db.execute(
|
||||
select(PromptTemplate).where(
|
||||
PromptTemplate.user_id == user_id,
|
||||
PromptTemplate.template_key == data.template_key
|
||||
)
|
||||
)
|
||||
template = result.scalar_one_or_none()
|
||||
|
||||
if template:
|
||||
# 更新现有模板
|
||||
for key, value in data.model_dump(exclude_unset=True).items():
|
||||
setattr(template, key, value)
|
||||
logger.info(f"用户 {user_id} 更新模板 {data.template_key}")
|
||||
else:
|
||||
# 创建新模板
|
||||
template = PromptTemplate(
|
||||
user_id=user_id,
|
||||
**data.model_dump()
|
||||
)
|
||||
db.add(template)
|
||||
logger.info(f"用户 {user_id} 创建模板 {data.template_key}")
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(template)
|
||||
|
||||
return template
|
||||
|
||||
|
||||
@router.put("/{template_key}", response_model=PromptTemplateResponse)
|
||||
async def update_template(
|
||||
template_key: str,
|
||||
data: PromptTemplateUpdate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
更新提示词模板
|
||||
"""
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
result = await db.execute(
|
||||
select(PromptTemplate).where(
|
||||
PromptTemplate.user_id == user_id,
|
||||
PromptTemplate.template_key == template_key
|
||||
)
|
||||
)
|
||||
template = result.scalar_one_or_none()
|
||||
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail=f"模板 {template_key} 不存在")
|
||||
|
||||
# 更新模板
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(template, key, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(template)
|
||||
logger.info(f"用户 {user_id} 更新模板 {template_key}")
|
||||
|
||||
return template
|
||||
|
||||
|
||||
@router.delete("/{template_key}")
|
||||
async def delete_template(
|
||||
template_key: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
删除自定义提示词模板
|
||||
"""
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
result = await db.execute(
|
||||
select(PromptTemplate).where(
|
||||
PromptTemplate.user_id == user_id,
|
||||
PromptTemplate.template_key == template_key
|
||||
)
|
||||
)
|
||||
template = result.scalar_one_or_none()
|
||||
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail=f"模板 {template_key} 不存在")
|
||||
|
||||
await db.delete(template)
|
||||
await db.commit()
|
||||
logger.info(f"用户 {user_id} 删除模板 {template_key}")
|
||||
|
||||
return {"message": "模板已删除", "template_key": template_key}
|
||||
|
||||
|
||||
@router.post("/{template_key}/reset")
|
||||
async def reset_to_default(
|
||||
template_key: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
重置为系统默认模板(删除用户自定义版本)
|
||||
"""
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
# 验证系统默认模板是否存在
|
||||
system_template = PromptService.get_system_template_info(template_key)
|
||||
if not system_template:
|
||||
raise HTTPException(status_code=404, detail=f"系统默认模板 {template_key} 不存在")
|
||||
|
||||
# 查找并删除用户的自定义模板
|
||||
result = await db.execute(
|
||||
select(PromptTemplate).where(
|
||||
PromptTemplate.user_id == user_id,
|
||||
PromptTemplate.template_key == template_key
|
||||
)
|
||||
)
|
||||
template = result.scalar_one_or_none()
|
||||
|
||||
if template:
|
||||
await db.delete(template)
|
||||
await db.commit()
|
||||
logger.info(f"用户 {user_id} 删除自定义模板 {template_key},恢复为系统默认")
|
||||
return {"message": "已重置为系统默认", "template_key": template_key}
|
||||
else:
|
||||
# 用户本来就没有自定义,已经是系统默认状态
|
||||
logger.info(f"用户 {user_id} 的模板 {template_key} 本来就是系统默认")
|
||||
return {"message": "已是系统默认状态", "template_key": template_key}
|
||||
|
||||
|
||||
@router.post("/export", response_model=PromptTemplateExport)
|
||||
async def export_templates(
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
导出所有提示词模板(包括用户自定义和系统默认)
|
||||
- 用户自定义的提示词标记为 is_customized=true
|
||||
- 系统默认的提示词标记为 is_customized=false
|
||||
"""
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
# 1. 查询用户自定义模板
|
||||
result = await db.execute(
|
||||
select(PromptTemplate).where(PromptTemplate.user_id == user_id)
|
||||
)
|
||||
user_templates = result.scalars().all()
|
||||
|
||||
# 2. 获取所有系统默认模板
|
||||
system_templates = PromptService.get_all_system_templates()
|
||||
|
||||
# 3. 构建用户自定义模板的键集合
|
||||
user_template_keys = {t.template_key for t in user_templates}
|
||||
|
||||
# 4. 准备导出数据
|
||||
export_items = []
|
||||
customized_count = 0
|
||||
system_default_count = 0
|
||||
|
||||
# 添加用户自定义的模板
|
||||
for user_template in user_templates:
|
||||
# 获取对应的系统模板用于计算哈希
|
||||
system_template = next(
|
||||
(t for t in system_templates if t["template_key"] == user_template.template_key),
|
||||
None
|
||||
)
|
||||
system_hash = calculate_content_hash(system_template["content"]) if system_template else None
|
||||
|
||||
export_items.append(PromptTemplateExportItem(
|
||||
template_key=user_template.template_key,
|
||||
template_name=user_template.template_name,
|
||||
template_content=user_template.template_content,
|
||||
description=user_template.description,
|
||||
category=user_template.category,
|
||||
parameters=user_template.parameters,
|
||||
is_active=user_template.is_active,
|
||||
is_customized=True,
|
||||
system_content_hash=system_hash
|
||||
))
|
||||
customized_count += 1
|
||||
|
||||
# 添加未自定义的系统默认模板
|
||||
for sys_template in system_templates:
|
||||
if sys_template['template_key'] not in user_template_keys:
|
||||
export_items.append(PromptTemplateExportItem(
|
||||
template_key=sys_template['template_key'],
|
||||
template_name=sys_template['template_name'],
|
||||
template_content=sys_template['content'],
|
||||
description=sys_template['description'],
|
||||
category=sys_template['category'],
|
||||
parameters=json.dumps(sys_template['parameters']),
|
||||
is_active=True,
|
||||
is_customized=False,
|
||||
system_content_hash=calculate_content_hash(sys_template['content'])
|
||||
))
|
||||
system_default_count += 1
|
||||
|
||||
statistics = {
|
||||
"total": len(export_items),
|
||||
"customized": customized_count,
|
||||
"system_default": system_default_count
|
||||
}
|
||||
|
||||
logger.info(f"用户 {user_id} 导出了 {statistics['total']} 个模板 "
|
||||
f"(自定义: {statistics['customized']}, 系统默认: {statistics['system_default']})")
|
||||
|
||||
return PromptTemplateExport(
|
||||
templates=export_items,
|
||||
export_time=datetime.now(),
|
||||
version="2.0",
|
||||
statistics=statistics
|
||||
)
|
||||
|
||||
|
||||
@router.post("/import", response_model=PromptTemplateImportResult)
|
||||
async def import_templates(
|
||||
data: PromptTemplateExport,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
智能导入提示词模板
|
||||
- 如果导入的是系统默认且内容未修改 → 删除自定义记录(使用系统默认)
|
||||
- 如果导入的是系统默认但内容已修改 → 创建自定义记录
|
||||
- 如果导入的是用户自定义 → 创建/更新自定义记录
|
||||
"""
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
# 获取所有系统默认模板用于比对
|
||||
system_templates = PromptService.get_all_system_templates()
|
||||
system_template_dict = {t["template_key"]: t for t in system_templates}
|
||||
|
||||
# 统计信息
|
||||
kept_system_default = 0 # 保持系统默认
|
||||
created_or_updated = 0 # 创建或更新自定义
|
||||
converted_to_custom = 0 # 从系统默认转为自定义
|
||||
converted_templates = [] # 被转换的模板列表
|
||||
|
||||
for template_data in data.templates:
|
||||
template_key = template_data.template_key
|
||||
is_customized = template_data.is_customized
|
||||
imported_content = template_data.template_content.strip()
|
||||
|
||||
# 查找当前用户是否已有该模板的自定义版本
|
||||
result = await db.execute(
|
||||
select(PromptTemplate).where(
|
||||
PromptTemplate.user_id == user_id,
|
||||
PromptTemplate.template_key == template_key
|
||||
)
|
||||
)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
# 获取系统默认模板
|
||||
system_template = system_template_dict.get(template_key)
|
||||
|
||||
if not is_customized:
|
||||
# 导入的标记为系统默认
|
||||
if system_template:
|
||||
system_content = system_template["content"].strip()
|
||||
|
||||
# 比对内容是否与系统默认一致
|
||||
if imported_content == system_content:
|
||||
# 内容一致,删除自定义记录(如果有)
|
||||
if existing:
|
||||
await db.delete(existing)
|
||||
logger.info(f"用户 {user_id} 的模板 {template_key} 恢复为系统默认(删除自定义)")
|
||||
kept_system_default += 1
|
||||
else:
|
||||
# 内容不一致,用户修改过,创建/更新为自定义
|
||||
if existing:
|
||||
# 更新现有自定义
|
||||
existing.template_name = template_data.template_name
|
||||
existing.template_content = template_data.template_content
|
||||
existing.description = template_data.description
|
||||
existing.category = template_data.category
|
||||
existing.parameters = template_data.parameters
|
||||
existing.is_active = template_data.is_active
|
||||
else:
|
||||
# 创建新自定义
|
||||
new_template = PromptTemplate(
|
||||
user_id=user_id,
|
||||
template_key=template_data.template_key,
|
||||
template_name=template_data.template_name,
|
||||
template_content=template_data.template_content,
|
||||
description=template_data.description,
|
||||
category=template_data.category,
|
||||
parameters=template_data.parameters,
|
||||
is_active=template_data.is_active
|
||||
)
|
||||
db.add(new_template)
|
||||
|
||||
converted_to_custom += 1
|
||||
converted_templates.append({
|
||||
"template_key": template_key,
|
||||
"template_name": template_data.template_name,
|
||||
"reason": "内容与系统默认不一致,已转为自定义"
|
||||
})
|
||||
logger.info(f"用户 {user_id} 的模板 {template_key} 内容已修改,转为自定义")
|
||||
else:
|
||||
# 系统中不存在该模板,作为自定义导入
|
||||
if existing:
|
||||
existing.template_name = template_data.template_name
|
||||
existing.template_content = template_data.template_content
|
||||
existing.description = template_data.description
|
||||
existing.category = template_data.category
|
||||
existing.parameters = template_data.parameters
|
||||
existing.is_active = template_data.is_active
|
||||
else:
|
||||
new_template = PromptTemplate(
|
||||
user_id=user_id,
|
||||
template_key=template_data.template_key,
|
||||
template_name=template_data.template_name,
|
||||
template_content=template_data.template_content,
|
||||
description=template_data.description,
|
||||
category=template_data.category,
|
||||
parameters=template_data.parameters,
|
||||
is_active=template_data.is_active
|
||||
)
|
||||
db.add(new_template)
|
||||
created_or_updated += 1
|
||||
else:
|
||||
# 导入的标记为用户自定义,直接创建/更新
|
||||
if existing:
|
||||
existing.template_name = template_data.template_name
|
||||
existing.template_content = template_data.template_content
|
||||
existing.description = template_data.description
|
||||
existing.category = template_data.category
|
||||
existing.parameters = template_data.parameters
|
||||
existing.is_active = template_data.is_active
|
||||
else:
|
||||
new_template = PromptTemplate(
|
||||
user_id=user_id,
|
||||
template_key=template_data.template_key,
|
||||
template_name=template_data.template_name,
|
||||
template_content=template_data.template_content,
|
||||
description=template_data.description,
|
||||
category=template_data.category,
|
||||
parameters=template_data.parameters,
|
||||
is_active=template_data.is_active
|
||||
)
|
||||
db.add(new_template)
|
||||
created_or_updated += 1
|
||||
|
||||
await db.commit()
|
||||
|
||||
statistics = {
|
||||
"total": len(data.templates),
|
||||
"kept_system_default": kept_system_default,
|
||||
"created_or_updated": created_or_updated,
|
||||
"converted_to_custom": converted_to_custom
|
||||
}
|
||||
|
||||
logger.info(f"用户 {user_id} 导入完成: {statistics}")
|
||||
|
||||
return PromptTemplateImportResult(
|
||||
message="导入成功",
|
||||
statistics=statistics,
|
||||
converted_templates=converted_templates
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{template_key}/preview")
|
||||
async def preview_template(
|
||||
template_key: str,
|
||||
data: PromptTemplatePreviewRequest,
|
||||
request: Request
|
||||
):
|
||||
"""
|
||||
预览提示词模板(渲染变量)
|
||||
"""
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
try:
|
||||
# 使用PromptService的format_prompt方法
|
||||
rendered = PromptService.format_prompt(
|
||||
data.template_content,
|
||||
**data.parameters
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"rendered_content": rendered,
|
||||
"parameters_used": list(data.parameters.keys())
|
||||
}
|
||||
except KeyError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"缺少必需的参数: {str(e)}",
|
||||
"rendered_content": None
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"渲染失败: {str(e)}",
|
||||
"rendered_content": None
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
"""关系管理API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, or_, and_
|
||||
from typing import List, Optional
|
||||
@@ -12,6 +12,7 @@ from app.models.relationship import (
|
||||
OrganizationMember
|
||||
)
|
||||
from app.models.character import Character
|
||||
from app.models.project import Project
|
||||
from app.schemas.relationship import (
|
||||
RelationshipTypeResponse,
|
||||
CharacterRelationshipCreate,
|
||||
@@ -27,6 +28,26 @@ router = APIRouter(prefix="/relationships", tags=["关系管理"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
|
||||
"""验证用户是否有权访问指定项目"""
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
result = await db.execute(
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||
|
||||
return project
|
||||
|
||||
|
||||
@router.get("/types", response_model=List[RelationshipTypeResponse], summary="获取关系类型列表")
|
||||
async def get_relationship_types(db: AsyncSession = Depends(get_db)):
|
||||
"""获取所有预定义的关系类型"""
|
||||
@@ -38,9 +59,14 @@ async def get_relationship_types(db: AsyncSession = Depends(get_db)):
|
||||
@router.get("/project/{project_id}", response_model=List[CharacterRelationshipResponse], summary="获取项目的所有关系")
|
||||
async def get_project_relationships(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
character_id: Optional[str] = Query(None, description="筛选特定角色的关系"),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
"""
|
||||
获取项目中的所有角色关系
|
||||
|
||||
@@ -70,8 +96,13 @@ async def get_project_relationships(
|
||||
@router.get("/graph/{project_id}", response_model=RelationshipGraphData, summary="获取关系图谱数据")
|
||||
async def get_relationship_graph(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
"""
|
||||
获取用于可视化的关系图谱数据
|
||||
|
||||
@@ -122,6 +153,7 @@ async def get_relationship_graph(
|
||||
@router.post("/", response_model=CharacterRelationshipResponse, summary="创建角色关系")
|
||||
async def create_relationship(
|
||||
relationship: CharacterRelationshipCreate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -131,6 +163,10 @@ async def create_relationship(
|
||||
- 可以指定预定义的关系类型或自定义关系名称
|
||||
- 可以设置亲密度、状态等属性
|
||||
"""
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(relationship.project_id, user_id, db)
|
||||
|
||||
# 验证角色是否存在
|
||||
char_from = await db.execute(
|
||||
select(Character).where(Character.id == relationship.character_from_id)
|
||||
@@ -161,6 +197,7 @@ async def create_relationship(
|
||||
async def update_relationship(
|
||||
relationship_id: str,
|
||||
relationship: CharacterRelationshipUpdate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""更新角色关系的属性(亲密度、状态等)"""
|
||||
@@ -174,6 +211,10 @@ async def update_relationship(
|
||||
if not db_rel:
|
||||
raise HTTPException(status_code=404, detail="关系不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(db_rel.project_id, user_id, db)
|
||||
|
||||
# 更新字段
|
||||
update_data = relationship.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
@@ -189,6 +230,7 @@ async def update_relationship(
|
||||
@router.delete("/{relationship_id}", summary="删除关系")
|
||||
async def delete_relationship(
|
||||
relationship_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""删除角色关系"""
|
||||
@@ -202,6 +244,10 @@ async def delete_relationship(
|
||||
if not db_rel:
|
||||
raise HTTPException(status_code=404, detail="关系不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(db_rel.project_id, user_id, db)
|
||||
|
||||
await db.delete(db_rel)
|
||||
await db.commit()
|
||||
|
||||
|
||||
+618
-25
@@ -4,18 +4,25 @@
|
||||
from fastapi import APIRouter, HTTPException, Request, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from typing import Dict, Any, List
|
||||
from typing import Dict, Any, List, Optional
|
||||
from pathlib import Path
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
import httpx
|
||||
import json
|
||||
import time
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.settings import Settings
|
||||
from app.schemas.settings import SettingsCreate, SettingsUpdate, SettingsResponse
|
||||
from app.schemas.settings import (
|
||||
SettingsCreate, SettingsUpdate, SettingsResponse,
|
||||
APIKeyPreset, APIKeyPresetConfig, PresetCreateRequest,
|
||||
PresetUpdateRequest, PresetResponse, PresetListResponse
|
||||
)
|
||||
from app.user_manager import User
|
||||
from app.logger import get_logger
|
||||
from app.config import settings as app_settings, PROJECT_ROOT
|
||||
from app.services.ai_service import AIService, create_user_ai_service
|
||||
from app.services.ai_service import AIService, create_user_ai_service, create_user_ai_service_with_mcp
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -46,9 +53,14 @@ async def get_user_ai_service(
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> AIService:
|
||||
"""
|
||||
依赖:获取当前用户的AI服务实例
|
||||
从数据库读取用户设置并创建对应的AI服务
|
||||
依赖:获取当前用户的AI服务实例(支持MCP工具自动加载)
|
||||
|
||||
从数据库读取用户设置并创建对应的AI服务。
|
||||
自动传递 user_id 和 db_session,使得 AIService 能够加载用户配置的MCP工具。
|
||||
根据用户的所有MCP插件状态决定是否启用MCP:如果有启用的插件则启用,否则禁用。
|
||||
"""
|
||||
from app.models.mcp_plugin import MCPPlugin
|
||||
|
||||
result = await db.execute(
|
||||
select(Settings).where(Settings.user_id == user.user_id)
|
||||
)
|
||||
@@ -66,14 +78,34 @@ async def get_user_ai_service(
|
||||
await db.refresh(settings)
|
||||
logger.info(f"用户 {user.user_id} 首次使用AI服务,已从.env同步设置到数据库")
|
||||
|
||||
# 使用用户设置创建AI服务实例
|
||||
return create_user_ai_service(
|
||||
# 查询用户的所有MCP插件状态
|
||||
mcp_result = await db.execute(
|
||||
select(MCPPlugin).where(MCPPlugin.user_id == user.user_id)
|
||||
)
|
||||
mcp_plugins = mcp_result.scalars().all()
|
||||
|
||||
# 检查是否有启用的MCP插件
|
||||
enable_mcp = any(plugin.enabled for plugin in mcp_plugins) if mcp_plugins else False
|
||||
|
||||
if mcp_plugins:
|
||||
enabled_count = sum(1 for p in mcp_plugins if p.enabled)
|
||||
logger.info(f"用户 {user.user_id} 有 {len(mcp_plugins)} 个MCP插件,{enabled_count} 个启用,{enable_mcp} 决定使用MCP")
|
||||
else:
|
||||
logger.debug(f"用户 {user.user_id} 没有配置MCP插件,禁用MCP")
|
||||
|
||||
# ✅ 使用支持MCP的工厂函数创建AI服务实例
|
||||
# 传递 user_id 和 db_session,使得 AIService 能够自动加载用户配置的MCP工具
|
||||
return create_user_ai_service_with_mcp(
|
||||
api_provider=settings.api_provider,
|
||||
api_key=settings.api_key,
|
||||
api_base_url=settings.api_base_url or "",
|
||||
model_name=settings.llm_model,
|
||||
temperature=settings.temperature,
|
||||
max_tokens=settings.max_tokens
|
||||
max_tokens=settings.max_tokens,
|
||||
user_id=user.user_id, # ✅ 传递 user_id
|
||||
db_session=db, # ✅ 传递 db_session
|
||||
system_prompt=settings.system_prompt,
|
||||
enable_mcp=enable_mcp, # 根据MCP插件状态动态决定
|
||||
)
|
||||
|
||||
|
||||
@@ -242,10 +274,8 @@ async def get_available_models(
|
||||
if "data" in data and isinstance(data["data"], list):
|
||||
for model in data["data"]:
|
||||
model_id = model.get("id", "")
|
||||
# 过滤出常用的文本生成模型
|
||||
if any(keyword in model_id.lower() for keyword in [
|
||||
"gpt", "gemini", "claude", "llama", "mistral", "qwen", "deepseek"
|
||||
]):
|
||||
# 返回所有模型,不进行过滤
|
||||
if model_id:
|
||||
models.append({
|
||||
"value": model_id,
|
||||
"label": model_id,
|
||||
@@ -266,17 +296,30 @@ async def get_available_models(
|
||||
}
|
||||
|
||||
elif provider == "anthropic":
|
||||
# Anthropic 没有公开的模型列表API
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Anthropic 不支持自动获取模型列表,请手动输入模型名称"
|
||||
)
|
||||
# Anthropic models API
|
||||
url = f"{api_base_url.rstrip('/')}/v1/models"
|
||||
headers = {"x-api-key": api_key, "anthropic-version": "2023-06-01"}
|
||||
response = await client.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
models = [{"value": m["id"], "label": m["id"], "description": m.get("display_name", "")} for m in data.get("data", [])]
|
||||
return {"provider": provider, "models": models, "count": len(models)}
|
||||
|
||||
elif provider == "gemini":
|
||||
# Gemini models API
|
||||
url = f"{api_base_url.rstrip('/')}/models?key={api_key}"
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
models = []
|
||||
for m in data.get("models", []):
|
||||
if "generateContent" in m.get("supportedGenerationMethods", []):
|
||||
mid = m.get("name", "").replace("models/", "")
|
||||
models.append({"value": mid, "label": m.get("displayName", mid), "description": ""})
|
||||
return {"provider": provider, "models": models, "count": len(models)}
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"不支持的提供商: {provider}"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=f"不支持的提供商: {provider}")
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"获取模型列表失败 (HTTP {e.response.status_code}): {e.response.text}")
|
||||
@@ -308,6 +351,227 @@ class ApiTestRequest(BaseModel):
|
||||
llm_model: str
|
||||
|
||||
|
||||
@router.post("/check-function-calling")
|
||||
async def check_function_calling_support(data: ApiTestRequest):
|
||||
"""
|
||||
检查模型是否支持 Function Calling(工具调用)
|
||||
|
||||
基于业界最佳实践的测试方法:
|
||||
1. 发送包含工具定义的请求
|
||||
2. 检查响应的 finish_reason 是否为 "tool_calls"
|
||||
3. 验证响应中是否包含有效的 tool_calls 数据
|
||||
|
||||
Args:
|
||||
data: 包含 API 配置的请求数据
|
||||
|
||||
Returns:
|
||||
检测结果包含支持状态、详细信息和建议
|
||||
"""
|
||||
api_key = data.api_key
|
||||
api_base_url = data.api_base_url
|
||||
provider = data.provider
|
||||
llm_model = data.llm_model
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# 定义一个简单的测试工具(天气查询)
|
||||
test_tools = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "获取指定城市的当前天气信息",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "城市名称,例如:北京、上海、深圳"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "温度单位"
|
||||
}
|
||||
},
|
||||
"required": ["city"]
|
||||
}
|
||||
}
|
||||
}]
|
||||
|
||||
# 测试提示:故意设计一个需要调用工具的问题
|
||||
test_prompt = "请告诉我北京现在的天气情况如何?"
|
||||
|
||||
logger.info(f"🧪 开始检测 Function Calling 支持")
|
||||
logger.info(f" - 提供商: {provider}")
|
||||
logger.info(f" - 模型: {llm_model}")
|
||||
logger.info(f" - 测试工具: get_weather")
|
||||
|
||||
# 创建临时 AI 服务实例进行测试
|
||||
test_service = AIService(
|
||||
api_provider=provider,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
default_model=llm_model,
|
||||
default_temperature=0.3, # 使用较低温度以获得更确定的行为
|
||||
default_max_tokens=200
|
||||
)
|
||||
|
||||
# 发送带工具的测试请求
|
||||
response = await test_service.generate_text(
|
||||
prompt=test_prompt,
|
||||
provider=provider,
|
||||
model=llm_model,
|
||||
temperature=0.3,
|
||||
max_tokens=200,
|
||||
tools=test_tools,
|
||||
tool_choice="auto", # 让模型自动决定是否使用工具
|
||||
auto_mcp=False # 禁用 MCP 自动加载
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
response_time = round((end_time - start_time) * 1000, 2)
|
||||
|
||||
# 分析响应以确定是否支持 Function Calling
|
||||
supported = False
|
||||
finish_reason = None
|
||||
tool_calls = None
|
||||
response_content = None
|
||||
|
||||
if isinstance(response, dict):
|
||||
# 检查 finish_reason(OpenAI 标准)
|
||||
finish_reason = response.get("finish_reason")
|
||||
|
||||
# 检查是否有 tool_calls
|
||||
if "tool_calls" in response and response["tool_calls"]:
|
||||
supported = True
|
||||
tool_calls = response["tool_calls"]
|
||||
logger.info(f"✅ 检测到工具调用: {len(tool_calls)} 个")
|
||||
|
||||
# 记录返回的内容(如果有)
|
||||
if "content" in response:
|
||||
response_content = response["content"]
|
||||
elif isinstance(response, str):
|
||||
# 如果只返回字符串,说明不支持工具调用
|
||||
response_content = response
|
||||
|
||||
logger.info(f" - 响应时间: {response_time}ms")
|
||||
logger.info(f" - finish_reason: {finish_reason}")
|
||||
logger.info(f" - 支持状态: {'✅ 支持' if supported else '❌ 不支持'}")
|
||||
|
||||
# 构建详细的返回信息
|
||||
result = {
|
||||
"success": True,
|
||||
"supported": supported,
|
||||
"message": "✅ 模型支持 Function Calling" if supported else "❌ 模型不支持 Function Calling",
|
||||
"response_time_ms": response_time,
|
||||
"provider": provider,
|
||||
"model": llm_model,
|
||||
"details": {
|
||||
"finish_reason": finish_reason,
|
||||
"has_tool_calls": bool(tool_calls),
|
||||
"tool_call_count": len(tool_calls) if tool_calls else 0,
|
||||
"test_tool": "get_weather",
|
||||
"test_prompt": test_prompt,
|
||||
"response_type": "tool_calls" if supported else "text"
|
||||
}
|
||||
}
|
||||
|
||||
# 添加工具调用详情
|
||||
if tool_calls:
|
||||
result["tool_calls"] = tool_calls
|
||||
result["suggestions"] = [
|
||||
"✅ 该模型支持 Function Calling,可以正常使用 MCP 插件",
|
||||
"建议:启用需要的 MCP 插件以扩展 AI 能力",
|
||||
"提示:测试成功检测到工具调用,模型能够正确解析和使用外部工具"
|
||||
]
|
||||
else:
|
||||
result["response_preview"] = response_content[:200] if response_content else None
|
||||
result["suggestions"] = [
|
||||
"❌ 该模型不支持 Function Calling,无法使用 MCP 插件功能",
|
||||
"建议:更换支持工具调用的模型",
|
||||
"推荐模型:GPT-4 系列、GPT-4-turbo、Claude 3 Opus/Sonnet、Gemini 1.5 Pro 等",
|
||||
"说明:模型返回了文本回复而非工具调用,表明不支持该功能"
|
||||
]
|
||||
|
||||
return result
|
||||
|
||||
except ValueError as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f"❌ Function Calling 检测配置错误: {error_msg}")
|
||||
return {
|
||||
"success": False,
|
||||
"supported": False,
|
||||
"message": "配置错误",
|
||||
"error": error_msg,
|
||||
"error_type": "ConfigurationError",
|
||||
"suggestions": [
|
||||
"请检查 API Key 是否正确",
|
||||
"请确认 API Base URL 格式是否正确",
|
||||
"请验证所选提供商与配置是否匹配"
|
||||
]
|
||||
}
|
||||
|
||||
except TimeoutError as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f"❌ Function Calling 检测超时: {error_msg}")
|
||||
return {
|
||||
"success": False,
|
||||
"supported": None,
|
||||
"message": "检测超时",
|
||||
"error": error_msg,
|
||||
"error_type": "TimeoutError",
|
||||
"suggestions": [
|
||||
"请检查网络连接是否正常",
|
||||
"请确认 API 服务是否可访问",
|
||||
"建议:稍后重试或使用其他网络环境"
|
||||
]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
error_type = type(e).__name__
|
||||
|
||||
logger.error(f"❌ Function Calling 检测失败: {error_msg}")
|
||||
logger.error(f" - 错误类型: {error_type}")
|
||||
|
||||
# 智能分析错误原因
|
||||
suggestions = []
|
||||
if "tool" in error_msg.lower() or "function" in error_msg.lower():
|
||||
suggestions = [
|
||||
"该模型可能不支持 Function Calling 功能",
|
||||
"API 返回了与工具调用相关的错误",
|
||||
"建议:更换支持工具调用的模型或联系 API 提供商"
|
||||
]
|
||||
elif "unauthorized" in error_msg.lower() or "401" in error_msg:
|
||||
suggestions = [
|
||||
"API Key 认证失败",
|
||||
"请检查 API Key 是否正确且有效",
|
||||
"请确认 API Key 是否有足够的权限"
|
||||
]
|
||||
elif "not found" in error_msg.lower() or "404" in error_msg:
|
||||
suggestions = [
|
||||
"模型不存在或不可用",
|
||||
"请检查模型名称是否正确",
|
||||
"请确认该模型在当前 API 中是否可用"
|
||||
]
|
||||
else:
|
||||
suggestions = [
|
||||
"检测过程中遇到未知错误",
|
||||
"建议:检查所有配置参数是否正确",
|
||||
"提示:查看详细错误信息以获取更多线索"
|
||||
]
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"supported": False,
|
||||
"message": "Function Calling 检测失败",
|
||||
"error": error_msg,
|
||||
"error_type": error_type,
|
||||
"suggestions": suggestions
|
||||
}
|
||||
|
||||
|
||||
@router.post("/test")
|
||||
async def test_api_connection(data: ApiTestRequest):
|
||||
"""
|
||||
@@ -351,7 +615,8 @@ async def test_api_connection(data: ApiTestRequest):
|
||||
provider=provider,
|
||||
model=llm_model,
|
||||
temperature=0.7,
|
||||
max_tokens=8000
|
||||
max_tokens=8000,
|
||||
auto_mcp=False # 测试时不加载MCP工具
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
@@ -359,7 +624,10 @@ async def test_api_connection(data: ApiTestRequest):
|
||||
|
||||
logger.info(f"✅ API 测试成功")
|
||||
logger.info(f" - 响应时间: {response_time}ms")
|
||||
logger.info(f" - 响应内容: {response[:100] if response else 'N/A'}")
|
||||
|
||||
# 安全地处理响应内容(确保是字符串)
|
||||
response_str = str(response) if response else 'N/A'
|
||||
logger.info(f" - 响应内容: {response_str[:100]}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
@@ -367,7 +635,7 @@ async def test_api_connection(data: ApiTestRequest):
|
||||
"response_time_ms": response_time,
|
||||
"provider": provider,
|
||||
"model": llm_model,
|
||||
"response_preview": response[:100] if response and len(response) > 100 else response,
|
||||
"response_preview": response_str[:100] if len(response_str) > 100 else response_str,
|
||||
"details": {
|
||||
"api_available": True,
|
||||
"model_accessible": True,
|
||||
@@ -461,4 +729,329 @@ async def test_api_connection(data: ApiTestRequest):
|
||||
"error": error_msg,
|
||||
"error_type": error_type,
|
||||
"suggestions": suggestions
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# ========== API配置预设管理(零数据库改动方案)==========
|
||||
|
||||
async def get_user_settings(user_id: str, db: AsyncSession) -> Settings:
|
||||
"""获取用户settings,如果不存在则创建"""
|
||||
result = await db.execute(
|
||||
select(Settings).where(Settings.user_id == user_id)
|
||||
)
|
||||
settings = result.scalar_one_or_none()
|
||||
|
||||
if not settings:
|
||||
# 创建默认设置
|
||||
env_defaults = read_env_defaults()
|
||||
settings = Settings(
|
||||
user_id=user_id,
|
||||
**env_defaults,
|
||||
preferences='{}' # 初始化为空JSON
|
||||
)
|
||||
db.add(settings)
|
||||
await db.commit()
|
||||
await db.refresh(settings)
|
||||
logger.info(f"用户 {user_id} 首次访问,已创建默认设置")
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
@router.get("/presets", response_model=PresetListResponse)
|
||||
async def get_presets(
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取所有API配置预设
|
||||
|
||||
从preferences字段读取预设列表
|
||||
"""
|
||||
settings = await get_user_settings(user.user_id, db)
|
||||
|
||||
# 解析preferences
|
||||
try:
|
||||
prefs = json.loads(settings.preferences or '{}')
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"用户 {user.user_id} 的preferences字段JSON格式错误,重置为空")
|
||||
prefs = {}
|
||||
|
||||
api_presets = prefs.get('api_presets', {'presets': [], 'version': '1.0'})
|
||||
presets = api_presets.get('presets', [])
|
||||
|
||||
# 找到激活的预设
|
||||
active_preset_id = next(
|
||||
(p['id'] for p in presets if p.get('is_active')),
|
||||
None
|
||||
)
|
||||
|
||||
logger.info(f"用户 {user.user_id} 获取预设列表,共 {len(presets)} 个")
|
||||
|
||||
return {
|
||||
"presets": presets,
|
||||
"total": len(presets),
|
||||
"active_preset_id": active_preset_id
|
||||
}
|
||||
|
||||
|
||||
@router.post("/presets", response_model=PresetResponse)
|
||||
async def create_preset(
|
||||
data: PresetCreateRequest,
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
创建新预设
|
||||
|
||||
将预设添加到preferences字段的JSON中
|
||||
"""
|
||||
settings = await get_user_settings(user.user_id, db)
|
||||
|
||||
# 解析preferences
|
||||
try:
|
||||
prefs = json.loads(settings.preferences or '{}')
|
||||
except json.JSONDecodeError:
|
||||
prefs = {}
|
||||
|
||||
api_presets = prefs.get('api_presets', {'presets': [], 'version': '1.0'})
|
||||
presets = api_presets.get('presets', [])
|
||||
|
||||
# 创建新预设
|
||||
new_preset = {
|
||||
"id": f"preset_{int(datetime.now().timestamp() * 1000)}",
|
||||
"name": data.name,
|
||||
"description": data.description,
|
||||
"is_active": False,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"config": data.config.model_dump()
|
||||
}
|
||||
|
||||
presets.append(new_preset)
|
||||
|
||||
# 保存回preferences
|
||||
api_presets['presets'] = presets
|
||||
prefs['api_presets'] = api_presets
|
||||
settings.preferences = json.dumps(prefs, ensure_ascii=False)
|
||||
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"用户 {user.user_id} 创建预设: {data.name}")
|
||||
return new_preset
|
||||
|
||||
|
||||
@router.put("/presets/{preset_id}", response_model=PresetResponse)
|
||||
async def update_preset(
|
||||
preset_id: str,
|
||||
data: PresetUpdateRequest,
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
更新预设
|
||||
|
||||
在preferences字段的JSON中更新指定预设
|
||||
"""
|
||||
settings = await get_user_settings(user.user_id, db)
|
||||
|
||||
# 解析preferences
|
||||
try:
|
||||
prefs = json.loads(settings.preferences or '{}')
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(status_code=500, detail="配置数据格式错误")
|
||||
|
||||
api_presets = prefs.get('api_presets', {'presets': [], 'version': '1.0'})
|
||||
presets = api_presets.get('presets', [])
|
||||
|
||||
# 找到并更新预设
|
||||
target_preset = next((p for p in presets if p['id'] == preset_id), None)
|
||||
if not target_preset:
|
||||
raise HTTPException(status_code=404, detail="预设不存在")
|
||||
|
||||
# 更新字段
|
||||
if data.name is not None:
|
||||
target_preset['name'] = data.name
|
||||
if data.description is not None:
|
||||
target_preset['description'] = data.description
|
||||
if data.config is not None:
|
||||
target_preset['config'] = data.config.model_dump()
|
||||
|
||||
# 保存回preferences
|
||||
prefs['api_presets'] = api_presets
|
||||
settings.preferences = json.dumps(prefs, ensure_ascii=False)
|
||||
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"用户 {user.user_id} 更新预设: {preset_id}")
|
||||
return target_preset
|
||||
|
||||
|
||||
@router.delete("/presets/{preset_id}")
|
||||
async def delete_preset(
|
||||
preset_id: str,
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
删除预设
|
||||
|
||||
从preferences字段的JSON中删除指定预设
|
||||
"""
|
||||
settings = await get_user_settings(user.user_id, db)
|
||||
|
||||
# 解析preferences
|
||||
try:
|
||||
prefs = json.loads(settings.preferences or '{}')
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(status_code=500, detail="配置数据格式错误")
|
||||
|
||||
api_presets = prefs.get('api_presets', {'presets': [], 'version': '1.0'})
|
||||
presets = api_presets.get('presets', [])
|
||||
|
||||
# 找到预设
|
||||
target_preset = next((p for p in presets if p['id'] == preset_id), None)
|
||||
if not target_preset:
|
||||
raise HTTPException(status_code=404, detail="预设不存在")
|
||||
|
||||
# 检查是否是激活的预设
|
||||
if target_preset.get('is_active'):
|
||||
raise HTTPException(status_code=400, detail="无法删除激活中的预设,请先激活其他预设")
|
||||
|
||||
# 删除预设
|
||||
presets = [p for p in presets if p['id'] != preset_id]
|
||||
|
||||
# 保存回preferences
|
||||
api_presets['presets'] = presets
|
||||
prefs['api_presets'] = api_presets
|
||||
settings.preferences = json.dumps(prefs, ensure_ascii=False)
|
||||
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"用户 {user.user_id} 删除预设: {preset_id}")
|
||||
return {"message": "预设已删除", "preset_id": preset_id}
|
||||
|
||||
|
||||
@router.post("/presets/{preset_id}/activate")
|
||||
async def activate_preset(
|
||||
preset_id: str,
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
激活预设
|
||||
|
||||
将预设的配置应用到Settings主字段
|
||||
"""
|
||||
settings = await get_user_settings(user.user_id, db)
|
||||
|
||||
# 解析preferences
|
||||
try:
|
||||
prefs = json.loads(settings.preferences or '{}')
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(status_code=500, detail="配置数据格式错误")
|
||||
|
||||
api_presets = prefs.get('api_presets', {'presets': [], 'version': '1.0'})
|
||||
presets = api_presets.get('presets', [])
|
||||
|
||||
# 找到目标预设
|
||||
target_preset = next((p for p in presets if p['id'] == preset_id), None)
|
||||
if not target_preset:
|
||||
raise HTTPException(status_code=404, detail="预设不存在")
|
||||
|
||||
# 应用配置到Settings主字段
|
||||
config = target_preset['config']
|
||||
settings.api_provider = config['api_provider']
|
||||
settings.api_key = config['api_key']
|
||||
settings.api_base_url = config.get('api_base_url')
|
||||
settings.llm_model = config['llm_model']
|
||||
settings.temperature = config['temperature']
|
||||
settings.max_tokens = config['max_tokens']
|
||||
|
||||
# 更新所有预设的is_active状态
|
||||
for preset in presets:
|
||||
preset['is_active'] = (preset['id'] == preset_id)
|
||||
|
||||
# 保存回preferences
|
||||
prefs['api_presets'] = api_presets
|
||||
settings.preferences = json.dumps(prefs, ensure_ascii=False)
|
||||
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"用户 {user.user_id} 激活预设: {target_preset['name']}")
|
||||
return {
|
||||
"message": "预设已激活",
|
||||
"preset_id": preset_id,
|
||||
"preset_name": target_preset['name']
|
||||
}
|
||||
|
||||
|
||||
@router.post("/presets/{preset_id}/test")
|
||||
async def test_preset(
|
||||
preset_id: str,
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
测试预设的API连接
|
||||
"""
|
||||
settings = await get_user_settings(user.user_id, db)
|
||||
|
||||
# 解析preferences
|
||||
try:
|
||||
prefs = json.loads(settings.preferences or '{}')
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(status_code=500, detail="配置数据格式错误")
|
||||
|
||||
api_presets = prefs.get('api_presets', {'presets': [], 'version': '1.0'})
|
||||
presets = api_presets.get('presets', [])
|
||||
|
||||
# 找到预设
|
||||
target_preset = next((p for p in presets if p['id'] == preset_id), None)
|
||||
if not target_preset:
|
||||
raise HTTPException(status_code=404, detail="预设不存在")
|
||||
|
||||
# 使用现有的test_api_connection逻辑
|
||||
config = target_preset['config']
|
||||
test_request = ApiTestRequest(
|
||||
api_key=config['api_key'],
|
||||
api_base_url=config.get('api_base_url', ''),
|
||||
provider=config['api_provider'],
|
||||
llm_model=config['llm_model']
|
||||
)
|
||||
|
||||
logger.info(f"用户 {user.user_id} 测试预设: {target_preset['name']}")
|
||||
return await test_api_connection(test_request)
|
||||
|
||||
|
||||
@router.post("/presets/from-current", response_model=PresetResponse)
|
||||
async def create_preset_from_current(
|
||||
name: str,
|
||||
description: Optional[str] = None,
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
从当前配置创建新预设
|
||||
|
||||
快捷方式:将当前激活的配置保存为新预设
|
||||
"""
|
||||
settings = await get_user_settings(user.user_id, db)
|
||||
|
||||
# 从当前Settings主字段读取配置
|
||||
current_config = APIKeyPresetConfig(
|
||||
api_provider=settings.api_provider,
|
||||
api_key=settings.api_key,
|
||||
api_base_url=settings.api_base_url,
|
||||
llm_model=settings.llm_model,
|
||||
temperature=settings.temperature,
|
||||
max_tokens=settings.max_tokens
|
||||
)
|
||||
|
||||
# 创建预设
|
||||
create_request = PresetCreateRequest(
|
||||
name=name,
|
||||
description=description,
|
||||
config=current_config
|
||||
)
|
||||
|
||||
logger.info(f"用户 {user.user_id} 从当前配置创建预设: {name}")
|
||||
return await create_preset(create_request, user, db)
|
||||
@@ -5,6 +5,7 @@ from fastapi import APIRouter, HTTPException, Request, Depends
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
from app.user_manager import user_manager, User
|
||||
from app.user_password import password_manager
|
||||
|
||||
router = APIRouter(prefix="/users", tags=["用户管理"])
|
||||
|
||||
@@ -29,6 +30,11 @@ class SetAdminRequest(BaseModel):
|
||||
is_admin: bool
|
||||
|
||||
|
||||
class ResetPasswordRequest(BaseModel):
|
||||
user_id: str
|
||||
new_password: Optional[str] = None # 如果为空则使用默认密码
|
||||
|
||||
|
||||
@router.get("/current")
|
||||
async def get_current_user(user: User = Depends(require_login)):
|
||||
"""获取当前登录用户信息"""
|
||||
@@ -122,4 +128,62 @@ async def get_user(
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
return user.dict()
|
||||
return user.dict()
|
||||
|
||||
|
||||
@router.post("/reset-password")
|
||||
async def reset_user_password(
|
||||
data: ResetPasswordRequest,
|
||||
admin_user: User = Depends(require_admin)
|
||||
):
|
||||
"""
|
||||
重置用户密码(仅管理员)
|
||||
|
||||
如果提供了 new_password,则设置为指定密码
|
||||
如果未提供 new_password,则重置为默认密码(username@666)
|
||||
|
||||
限制:
|
||||
- 不能重置自己的密码(应该使用修改密码功能)
|
||||
"""
|
||||
# 检查是否尝试重置自己的密码
|
||||
if data.user_id == admin_user.user_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="不能重置自己的密码,请使用修改密码功能"
|
||||
)
|
||||
|
||||
# 检查目标用户是否存在
|
||||
target_user = await user_manager.get_user(data.user_id)
|
||||
if not target_user:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="目标用户不存在"
|
||||
)
|
||||
|
||||
# 重置密码
|
||||
try:
|
||||
actual_password = await password_manager.set_password(
|
||||
target_user.user_id,
|
||||
target_user.username,
|
||||
data.new_password
|
||||
)
|
||||
|
||||
# 如果使用了默认密码,返回密码供管理员告知用户
|
||||
message = "密码重置成功"
|
||||
response_data = {
|
||||
"message": message,
|
||||
"user_id": data.user_id,
|
||||
"username": target_user.username
|
||||
}
|
||||
|
||||
if not data.new_password:
|
||||
response_data["default_password"] = actual_password
|
||||
response_data["message"] = f"密码已重置为默认密码: {actual_password}"
|
||||
|
||||
return response_data
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"重置密码失败: {str(e)}"
|
||||
)
|
||||
+973
-575
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,5 @@
|
||||
"""写作风格管理 API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, delete
|
||||
from typing import List
|
||||
@@ -15,62 +15,90 @@ from ..schemas.writing_style import (
|
||||
WritingStyleListResponse,
|
||||
SetDefaultStyleRequest
|
||||
)
|
||||
from ..services.prompt_service import WritingStyleManager
|
||||
from ..logger import get_logger
|
||||
|
||||
router = APIRouter(prefix="/writing-styles", tags=["writing-styles"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def get_current_user_id(request: Request) -> str:
|
||||
"""获取当前登录用户ID"""
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
return user_id
|
||||
|
||||
|
||||
@router.get("/presets/list", response_model=List[dict])
|
||||
async def get_preset_styles():
|
||||
async def get_preset_styles(db: AsyncSession = Depends(get_db)):
|
||||
"""
|
||||
获取所有预设风格列表
|
||||
获取所有预设风格列表(从数据库读取)
|
||||
|
||||
返回格式:数组形式的预设风格列表
|
||||
[
|
||||
{"id": "natural", "name": "自然流畅", "description": "...", "prompt_content": "..."},
|
||||
{"id": "classical", "name": "古典优雅", ...}
|
||||
{"id": 1, "preset_id": "natural", "name": "自然流畅", "description": "...", "prompt_content": "..."},
|
||||
{"id": 2, "preset_id": "classical", "name": "古典优雅", ...}
|
||||
]
|
||||
"""
|
||||
presets = WritingStyleManager.get_all_presets()
|
||||
# 将字典转换为数组,添加 id 字段
|
||||
# 从数据库获取全局预设风格(user_id 为 NULL)
|
||||
result = await db.execute(
|
||||
select(WritingStyle)
|
||||
.where(WritingStyle.user_id.is_(None))
|
||||
.order_by(WritingStyle.order_index)
|
||||
)
|
||||
preset_styles = result.scalars().all()
|
||||
|
||||
# 转换为响应格式
|
||||
return [
|
||||
{"id": preset_id, **preset_data}
|
||||
for preset_id, preset_data in presets.items()
|
||||
{
|
||||
"id": style.id,
|
||||
"preset_id": style.preset_id,
|
||||
"name": style.name,
|
||||
"description": style.description,
|
||||
"prompt_content": style.prompt_content,
|
||||
"style_type": style.style_type,
|
||||
"order_index": style.order_index
|
||||
}
|
||||
for style in preset_styles
|
||||
]
|
||||
|
||||
|
||||
@router.post("", response_model=WritingStyleResponse, status_code=201)
|
||||
async def create_writing_style(
|
||||
style_data: WritingStyleCreate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
创建新的写作风格
|
||||
创建新的写作风格(用户级别)
|
||||
|
||||
- **基于预设创建**:提供 preset_id,系统会自动填充预设内容
|
||||
- **基于预设创建**:提供 preset_id,系统会从数据库查询预设内容自动填充
|
||||
- **完全自定义**:不提供 preset_id,需要手动填写所有字段
|
||||
"""
|
||||
# 验证项目是否存在
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == style_data.project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
# 获取当前用户ID
|
||||
user_id = get_current_user_id(request)
|
||||
|
||||
# 如果基于预设创建,获取预设内容
|
||||
# 如果基于预设创建,从数据库获取预设内容
|
||||
if style_data.preset_id:
|
||||
preset = WritingStyleManager.get_preset_style(style_data.preset_id)
|
||||
result = await db.execute(
|
||||
select(WritingStyle)
|
||||
.where(
|
||||
WritingStyle.user_id.is_(None),
|
||||
WritingStyle.preset_id == style_data.preset_id
|
||||
)
|
||||
)
|
||||
preset = result.scalar_one_or_none()
|
||||
|
||||
if not preset:
|
||||
raise HTTPException(status_code=400, detail=f"预设风格 '{style_data.preset_id}' 不存在")
|
||||
|
||||
# 使用预设内容填充(如果用户未提供)
|
||||
if not style_data.name:
|
||||
style_data.name = preset["name"]
|
||||
style_data.name = preset.name
|
||||
if not style_data.description:
|
||||
style_data.description = preset["description"]
|
||||
style_data.description = preset.description
|
||||
if not style_data.prompt_content:
|
||||
style_data.prompt_content = preset["prompt_content"]
|
||||
style_data.prompt_content = preset.prompt_content
|
||||
|
||||
# 验证必填字段
|
||||
if not style_data.name or not style_data.prompt_content:
|
||||
@@ -79,16 +107,16 @@ async def create_writing_style(
|
||||
detail="name 和 prompt_content 是必填字段"
|
||||
)
|
||||
|
||||
# 获取当前最大 order_index
|
||||
# 获取当前用户的最大 order_index
|
||||
count_result = await db.execute(
|
||||
select(func.count(WritingStyle.id))
|
||||
.where(WritingStyle.project_id == style_data.project_id)
|
||||
.where(WritingStyle.user_id == user_id)
|
||||
)
|
||||
max_order = count_result.scalar_one()
|
||||
|
||||
# 创建风格记录
|
||||
new_style = WritingStyle(
|
||||
project_id=style_data.project_id,
|
||||
user_id=user_id,
|
||||
name=style_data.name,
|
||||
style_type=style_data.style_type or ("preset" if style_data.preset_id else "custom"),
|
||||
preset_id=style_data.preset_id,
|
||||
@@ -104,7 +132,7 @@ async def create_writing_style(
|
||||
# 返回包含 is_default 字段的字典(新创建的风格默认不是默认风格)
|
||||
return {
|
||||
"id": new_style.id,
|
||||
"project_id": new_style.project_id,
|
||||
"user_id": new_style.user_id,
|
||||
"name": new_style.name,
|
||||
"style_type": new_style.style_type,
|
||||
"preset_id": new_style.preset_id,
|
||||
@@ -117,24 +145,85 @@ async def create_writing_style(
|
||||
}
|
||||
|
||||
|
||||
@router.get("/project/{project_id}", response_model=WritingStyleListResponse)
|
||||
async def get_project_styles(
|
||||
project_id: str,
|
||||
@router.get("/user", response_model=WritingStyleListResponse)
|
||||
async def get_user_styles(
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取项目的所有可用写作风格
|
||||
获取用户的所有可用写作风格
|
||||
|
||||
返回:全局预设风格 + 该项目的自定义风格
|
||||
返回:全局预设风格 + 该用户的自定义风格
|
||||
按 order_index 排序
|
||||
"""
|
||||
# 获取当前用户ID
|
||||
user_id = get_current_user_id(request)
|
||||
|
||||
# 获取全局预设风格(user_id 为 NULL)
|
||||
result = await db.execute(
|
||||
select(WritingStyle)
|
||||
.where(WritingStyle.user_id.is_(None))
|
||||
.order_by(WritingStyle.order_index)
|
||||
)
|
||||
preset_styles = list(result.scalars().all())
|
||||
|
||||
# 获取用户自定义风格
|
||||
result = await db.execute(
|
||||
select(WritingStyle)
|
||||
.where(WritingStyle.user_id == user_id)
|
||||
.order_by(WritingStyle.order_index)
|
||||
)
|
||||
custom_styles = list(result.scalars().all())
|
||||
|
||||
# 合并:预设风格 + 自定义风格
|
||||
all_styles = preset_styles + custom_styles
|
||||
|
||||
# 转换为响应格式
|
||||
styles_with_default = []
|
||||
for style in all_styles:
|
||||
style_dict = {
|
||||
"id": style.id,
|
||||
"user_id": style.user_id,
|
||||
"name": style.name,
|
||||
"style_type": style.style_type,
|
||||
"preset_id": style.preset_id,
|
||||
"description": style.description,
|
||||
"prompt_content": style.prompt_content,
|
||||
"order_index": style.order_index,
|
||||
"created_at": style.created_at,
|
||||
"updated_at": style.updated_at,
|
||||
"is_default": False # 用户级别不再需要默认风格标记
|
||||
}
|
||||
styles_with_default.append(style_dict)
|
||||
|
||||
return {"styles": styles_with_default, "total": len(styles_with_default)}
|
||||
|
||||
|
||||
@router.get("/project/{project_id}", response_model=WritingStyleListResponse)
|
||||
async def get_project_styles(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取项目可用的所有写作风格(保留用于向后兼容)
|
||||
|
||||
返回:全局预设风格 + 该用户的自定义风格
|
||||
按 order_index 排序,并标记哪个是当前项目的默认风格
|
||||
"""
|
||||
# 验证项目是否存在
|
||||
# 获取当前用户ID
|
||||
user_id = get_current_user_id(request)
|
||||
|
||||
# 验证项目访问权限
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||
|
||||
# 获取该项目的默认风格ID
|
||||
result = await db.execute(
|
||||
@@ -143,18 +232,18 @@ async def get_project_styles(
|
||||
)
|
||||
default_style_id = result.scalar_one_or_none()
|
||||
|
||||
# 获取全局预设风格(project_id 为 NULL)
|
||||
# 获取全局预设风格(user_id 为 NULL)
|
||||
result = await db.execute(
|
||||
select(WritingStyle)
|
||||
.where(WritingStyle.project_id.is_(None))
|
||||
.where(WritingStyle.user_id.is_(None))
|
||||
.order_by(WritingStyle.order_index)
|
||||
)
|
||||
preset_styles = list(result.scalars().all())
|
||||
|
||||
# 获取项目自定义风格
|
||||
# 获取用户自定义风格
|
||||
result = await db.execute(
|
||||
select(WritingStyle)
|
||||
.where(WritingStyle.project_id == project_id)
|
||||
.where(WritingStyle.user_id == user_id)
|
||||
.order_by(WritingStyle.order_index)
|
||||
)
|
||||
custom_styles = list(result.scalars().all())
|
||||
@@ -167,7 +256,7 @@ async def get_project_styles(
|
||||
for style in all_styles:
|
||||
style_dict = {
|
||||
"id": style.id,
|
||||
"project_id": style.project_id,
|
||||
"user_id": style.user_id,
|
||||
"name": style.name,
|
||||
"style_type": style.style_type,
|
||||
"preset_id": style.preset_id,
|
||||
@@ -196,16 +285,16 @@ async def get_writing_style(
|
||||
if not style:
|
||||
raise HTTPException(status_code=404, detail="写作风格不存在")
|
||||
|
||||
# 检查是否有项目将其设置为默认风格
|
||||
# 检查是否有项目将其设置为默认风格(一个风格可能被多个项目使用,使用 first() 避免 MultipleResultsFound)
|
||||
result = await db.execute(
|
||||
select(ProjectDefaultStyle).where(ProjectDefaultStyle.style_id == style_id)
|
||||
)
|
||||
is_default = result.scalar_one_or_none() is not None
|
||||
is_default = result.scalars().first() is not None
|
||||
|
||||
# 返回包含 is_default 字段的字典
|
||||
return {
|
||||
"id": style.id,
|
||||
"project_id": style.project_id,
|
||||
"user_id": style.user_id,
|
||||
"name": style.name,
|
||||
"style_type": style.style_type,
|
||||
"preset_id": style.preset_id,
|
||||
@@ -222,6 +311,7 @@ async def get_writing_style(
|
||||
async def update_writing_style(
|
||||
style_id: int,
|
||||
style_data: WritingStyleUpdate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -230,6 +320,9 @@ async def update_writing_style(
|
||||
- 只能修改自定义风格
|
||||
- 不能修改全局预设风格
|
||||
"""
|
||||
# 获取当前用户ID
|
||||
user_id = get_current_user_id(request)
|
||||
|
||||
result = await db.execute(
|
||||
select(WritingStyle).where(WritingStyle.id == style_id)
|
||||
)
|
||||
@@ -238,9 +331,13 @@ async def update_writing_style(
|
||||
raise HTTPException(status_code=404, detail="写作风格不存在")
|
||||
|
||||
# 检查是否为全局预设风格(不允许修改)
|
||||
if style.project_id is None:
|
||||
if style.user_id is None:
|
||||
raise HTTPException(status_code=403, detail="不能修改全局预设风格,只能修改自定义风格")
|
||||
|
||||
# 验证用户权限(只能修改自己的风格)
|
||||
if style.user_id != user_id:
|
||||
raise HTTPException(status_code=403, detail="无权修改其他用户的风格")
|
||||
|
||||
# 更新字段
|
||||
update_data = style_data.model_dump(exclude_unset=True)
|
||||
|
||||
@@ -254,16 +351,16 @@ async def update_writing_style(
|
||||
await db.commit()
|
||||
await db.refresh(style)
|
||||
|
||||
# 检查是否有项目将其设置为默认风格
|
||||
# 检查是否有项目将其设置为默认风格(一个风格可能被多个项目使用,使用 first() 避免 MultipleResultsFound)
|
||||
result = await db.execute(
|
||||
select(ProjectDefaultStyle).where(ProjectDefaultStyle.style_id == style_id)
|
||||
)
|
||||
is_default = result.scalar_one_or_none() is not None
|
||||
is_default = result.scalars().first() is not None
|
||||
|
||||
# 返回包含 is_default 字段的字典
|
||||
return {
|
||||
"id": style.id,
|
||||
"project_id": style.project_id,
|
||||
"user_id": style.user_id,
|
||||
"name": style.name,
|
||||
"style_type": style.style_type,
|
||||
"preset_id": style.preset_id,
|
||||
@@ -279,6 +376,7 @@ async def update_writing_style(
|
||||
@router.delete("/{style_id}", status_code=204)
|
||||
async def delete_writing_style(
|
||||
style_id: int,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -289,6 +387,9 @@ async def delete_writing_style(
|
||||
- 不能删除默认风格(必须先设置其他风格为默认)
|
||||
- 删除后无法恢复
|
||||
"""
|
||||
# 获取当前用户ID
|
||||
user_id = get_current_user_id(request)
|
||||
|
||||
result = await db.execute(
|
||||
select(WritingStyle).where(WritingStyle.id == style_id)
|
||||
)
|
||||
@@ -297,14 +398,18 @@ async def delete_writing_style(
|
||||
raise HTTPException(status_code=404, detail="写作风格不存在")
|
||||
|
||||
# 检查是否为全局预设风格(不允许删除)
|
||||
if style.project_id is None:
|
||||
if style.user_id is None:
|
||||
raise HTTPException(status_code=403, detail="不能删除全局预设风格,只能删除自定义风格")
|
||||
|
||||
# 检查是否有项目将其设置为默认风格
|
||||
# 验证用户权限(只能删除自己的风格)
|
||||
if style.user_id != user_id:
|
||||
raise HTTPException(status_code=403, detail="无权删除其他用户的风格")
|
||||
|
||||
# 检查是否有项目将其设置为默认风格(一个风格可能被多个项目使用,使用 first() 避免 MultipleResultsFound)
|
||||
result = await db.execute(
|
||||
select(ProjectDefaultStyle).where(ProjectDefaultStyle.style_id == style_id)
|
||||
)
|
||||
default_relation = result.scalar_one_or_none()
|
||||
default_relation = result.scalars().first()
|
||||
if default_relation:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@@ -321,6 +426,7 @@ async def delete_writing_style(
|
||||
async def set_default_style(
|
||||
style_id: int,
|
||||
request_data: SetDefaultStyleRequest,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -335,13 +441,19 @@ async def set_default_style(
|
||||
"""
|
||||
project_id = request_data.project_id
|
||||
|
||||
# 验证项目是否存在
|
||||
# 获取当前用户ID
|
||||
user_id = get_current_user_id(request)
|
||||
|
||||
# 验证项目访问权限
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||
|
||||
# 验证风格是否存在
|
||||
result = await db.execute(
|
||||
@@ -351,9 +463,9 @@ async def set_default_style(
|
||||
if not style:
|
||||
raise HTTPException(status_code=404, detail="写作风格不存在")
|
||||
|
||||
# 验证风格是否属于该项目(自定义风格)或是全局预设风格
|
||||
if style.project_id is not None and style.project_id != project_id:
|
||||
raise HTTPException(status_code=403, detail="无权操作其他项目的风格")
|
||||
# 验证风格是否属于该用户(自定义风格)或是全局预设风格
|
||||
if style.user_id is not None and style.user_id != user_id:
|
||||
raise HTTPException(status_code=403, detail="无权操作其他用户的风格")
|
||||
|
||||
# 使用 UPSERT 逻辑:先删除该项目的旧默认风格记录,再插入新的
|
||||
await db.execute(
|
||||
@@ -379,6 +491,7 @@ async def set_default_style(
|
||||
@router.post("/project/{project_id}/init-defaults", response_model=WritingStyleListResponse)
|
||||
async def initialize_default_styles(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -387,13 +500,5 @@ async def initialize_default_styles(
|
||||
新架构下,预设风格是全局的,不需要为每个项目单独初始化
|
||||
该接口保留用于兼容性,直接返回项目可用的所有风格
|
||||
"""
|
||||
# 验证项目是否存在
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
# 直接返回项目可用的所有风格(全局预设 + 项目自定义)
|
||||
return await get_project_styles(project_id, db)
|
||||
# 直接返回项目可用的所有风格(全局预设 + 用户自定义)
|
||||
return await get_project_styles(project_id, request, db)
|
||||
+33
-8
@@ -3,6 +3,7 @@ from pydantic_settings import BaseSettings
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
import logging
|
||||
import os
|
||||
|
||||
# 获取项目根目录(从backend/app/config.py向上两级)
|
||||
PROJECT_ROOT = Path(__file__).parent.parent
|
||||
@@ -12,13 +13,11 @@ DATA_DIR.mkdir(exist_ok=True)
|
||||
# 配置模块使用标准logging(在logger.py初始化之前)
|
||||
config_logger = logging.getLogger(__name__)
|
||||
|
||||
# 数据库文件路径(绝对路径)
|
||||
DB_FILE = DATA_DIR / "ai_story.db"
|
||||
# 数据库配置:PostgreSQL
|
||||
# 从环境变量获取数据库URL
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql+asyncpg://mumuai:password@localhost:5432/mumuai_novel")
|
||||
|
||||
# 生成数据库URL(在类外部生成,确保使用绝对路径)
|
||||
# 将Windows反斜杠转换为正斜杠,SQLite URL格式要求
|
||||
DATABASE_URL = f"sqlite+aiosqlite:///{str(DB_FILE.absolute()).replace(chr(92), '/')}"
|
||||
config_logger.debug(f"数据库文件路径: {DB_FILE}")
|
||||
config_logger.debug(f"数据库类型: PostgreSQL")
|
||||
config_logger.debug(f"数据库URL: {DATABASE_URL}")
|
||||
|
||||
class Settings(BaseSettings):
|
||||
@@ -41,9 +40,31 @@ class Settings(BaseSettings):
|
||||
# CORS配置
|
||||
cors_origins: list[str] = ["http://localhost:8000", "http://127.0.0.1:8000"]
|
||||
|
||||
# 数据库配置 - 使用预先计算好的绝对路径URL
|
||||
# 数据库配置 - PostgreSQL
|
||||
database_url: str = DATABASE_URL
|
||||
|
||||
# PostgreSQL连接池配置(优化后支持150-200并发用户)
|
||||
database_pool_size: int = 50 # 核心连接池大小(优化:从30提升到50)
|
||||
database_max_overflow: int = 30 # 最大溢出连接数(优化:从20提升到30)
|
||||
database_pool_timeout: int = 90 # 连接池超时秒数(优化:从60提升到90)
|
||||
database_pool_recycle: int = 1800 # 连接回收时间秒数(30分钟,防止长时间连接失效)
|
||||
database_pool_pre_ping: bool = True # 连接前ping检测,确保连接有效
|
||||
database_pool_use_lifo: bool = True # 使用LIFO策略提高连接复用率
|
||||
|
||||
# 连接池高级配置
|
||||
database_echo_pool: bool = False # 是否记录连接池日志(调试用)
|
||||
database_pool_reset_on_return: str = "rollback" # 连接归还时的重置策略:rollback/commit/none
|
||||
database_max_identifier_length: int = 128 # PostgreSQL标识符最大长度
|
||||
|
||||
# 会话监控配置
|
||||
database_session_max_active: int = 50 # 活跃会话警告阈值(从100降低到50)
|
||||
database_session_leak_threshold: int = 100 # 会话泄漏严重告警阈值
|
||||
|
||||
# 数据库监控配置
|
||||
database_enable_slow_query_log: bool = True # 启用慢查询日志
|
||||
database_slow_query_threshold: float = 1.0 # 慢查询阈值(秒)
|
||||
database_enable_metrics: bool = True # 启用性能指标收集
|
||||
|
||||
# AI服务配置
|
||||
openai_api_key: Optional[str] = None
|
||||
openai_base_url: Optional[str] = None
|
||||
@@ -54,7 +75,10 @@ class Settings(BaseSettings):
|
||||
default_ai_provider: str = "openai"
|
||||
default_model: str = "gpt-4"
|
||||
default_temperature: float = 0.7
|
||||
default_max_tokens: int = 2000
|
||||
default_max_tokens: int = 32000
|
||||
|
||||
# MCP配置
|
||||
mcp_max_rounds: int = 3 # MCP工具调用最大轮数(全局统一控制)
|
||||
|
||||
# LinuxDO OAuth2 配置
|
||||
LINUXDO_CLIENT_ID: Optional[str] = None
|
||||
@@ -85,6 +109,7 @@ class Settings(BaseSettings):
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
case_sensitive = False
|
||||
extra = "ignore" # 忽略未定义的环境变量,避免验证错误
|
||||
|
||||
|
||||
# 创建全局配置实例
|
||||
|
||||
+286
-179
@@ -1,11 +1,10 @@
|
||||
"""数据库连接和会话管理 - 支持多用户数据隔离"""
|
||||
"""数据库连接和会话管理 - PostgreSQL 多用户数据隔离"""
|
||||
import asyncio
|
||||
from typing import Dict, Any
|
||||
from datetime import datetime
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.orm import declarative_base
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from fastapi import Request, HTTPException
|
||||
from app.config import settings
|
||||
from app.logger import get_logger
|
||||
@@ -21,7 +20,8 @@ from app.models import (
|
||||
Project, Outline, Character, Chapter, GenerationHistory,
|
||||
Settings, WritingStyle, ProjectDefaultStyle,
|
||||
RelationshipType, CharacterRelationship, Organization, OrganizationMember,
|
||||
StoryMemory, PlotAnalysis, AnalysisTask
|
||||
StoryMemory, PlotAnalysis, AnalysisTask, BatchGenerationTask,
|
||||
RegenerationTask, Career, CharacterCareer, User, MCPPlugin, PromptTemplate
|
||||
)
|
||||
|
||||
# 引擎缓存:每个用户一个引擎
|
||||
@@ -45,51 +45,101 @@ _session_stats = {
|
||||
async def get_engine(user_id: str):
|
||||
"""获取或创建用户专属的数据库引擎(线程安全)
|
||||
|
||||
PostgreSQL: 所有用户共享一个数据库,通过user_id字段隔离数据
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
用户专属的异步引擎
|
||||
"""
|
||||
if user_id in _engine_cache:
|
||||
return _engine_cache[user_id]
|
||||
# PostgreSQL模式:所有用户共享同一个引擎
|
||||
cache_key = "shared_postgres"
|
||||
if cache_key in _engine_cache:
|
||||
return _engine_cache[cache_key]
|
||||
|
||||
async with _cache_lock:
|
||||
if user_id not in _engine_locks:
|
||||
_engine_locks[user_id] = asyncio.Lock()
|
||||
user_lock = _engine_locks[user_id]
|
||||
|
||||
async with user_lock:
|
||||
if user_id not in _engine_cache:
|
||||
db_url = f"sqlite+aiosqlite:///data/ai_story_user_{user_id}.db"
|
||||
engine = create_async_engine(
|
||||
db_url,
|
||||
echo=False,
|
||||
future=True,
|
||||
poolclass=StaticPool,
|
||||
pool_pre_ping=True,
|
||||
pool_recycle=3600,
|
||||
connect_args={
|
||||
"timeout": 30,
|
||||
"check_same_thread": False
|
||||
}
|
||||
)
|
||||
if cache_key not in _engine_cache:
|
||||
# 检测数据库类型
|
||||
is_sqlite = 'sqlite' in settings.database_url.lower()
|
||||
|
||||
try:
|
||||
async with engine.begin() as conn:
|
||||
await conn.execute(text("PRAGMA journal_mode=WAL"))
|
||||
await conn.execute(text("PRAGMA synchronous=NORMAL"))
|
||||
await conn.execute(text("PRAGMA cache_size=-64000"))
|
||||
await conn.execute(text("PRAGMA temp_store=MEMORY"))
|
||||
await conn.execute(text("PRAGMA busy_timeout=5000"))
|
||||
# 基础引擎参数
|
||||
engine_args = {
|
||||
"echo": settings.database_echo_pool,
|
||||
"echo_pool": settings.database_echo_pool,
|
||||
"future": True,
|
||||
}
|
||||
|
||||
if is_sqlite:
|
||||
# SQLite 配置(使用 NullPool,不支持连接池参数)
|
||||
engine_args["connect_args"] = {
|
||||
"check_same_thread": False,
|
||||
"timeout": 30.0, # 等待锁释放的超时时间(秒)
|
||||
}
|
||||
# 启用连接前检测以支持更好的并发
|
||||
engine_args["pool_pre_ping"] = True
|
||||
|
||||
logger.info("📊 使用 SQLite 数据库(NullPool,超时30秒,WAL模式)")
|
||||
else:
|
||||
# PostgreSQL 配置(完整连接池支持)
|
||||
connect_args = {
|
||||
"server_settings": {
|
||||
"application_name": settings.app_name,
|
||||
"jit": "off",
|
||||
"search_path": "public",
|
||||
},
|
||||
"command_timeout": 60,
|
||||
"statement_cache_size": 500,
|
||||
}
|
||||
|
||||
engine_args.update({
|
||||
"pool_size": settings.database_pool_size,
|
||||
"max_overflow": settings.database_max_overflow,
|
||||
"pool_timeout": settings.database_pool_timeout,
|
||||
"pool_pre_ping": settings.database_pool_pre_ping,
|
||||
"pool_recycle": settings.database_pool_recycle,
|
||||
"pool_use_lifo": settings.database_pool_use_lifo,
|
||||
"pool_reset_on_return": settings.database_pool_reset_on_return,
|
||||
"max_identifier_length": settings.database_max_identifier_length,
|
||||
"connect_args": connect_args
|
||||
})
|
||||
|
||||
total_connections = settings.database_pool_size + settings.database_max_overflow
|
||||
estimated_concurrent_users = total_connections * 2
|
||||
|
||||
logger.info(
|
||||
f"📊 PostgreSQL 连接池配置:\n"
|
||||
f" ├─ 核心连接: {settings.database_pool_size}\n"
|
||||
f" ├─ 溢出连接: {settings.database_max_overflow}\n"
|
||||
f" ├─ 总连接数: {total_connections}\n"
|
||||
f" ├─ 获取超时: {settings.database_pool_timeout}秒\n"
|
||||
f" ├─ 连接回收: {settings.database_pool_recycle}秒\n"
|
||||
f" └─ 预估并发: {estimated_concurrent_users}+用户"
|
||||
)
|
||||
|
||||
engine = create_async_engine(settings.database_url, **engine_args)
|
||||
_engine_cache[cache_key] = engine
|
||||
|
||||
# 如果是 SQLite,启用 WAL 模式以支持读写并发
|
||||
if is_sqlite:
|
||||
try:
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
logger.info(f"✅ 用户 {user_id} 的数据库已优化(WAL模式 + 64MB缓存)")
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 用户 {user_id} 数据库优化失败: {str(e)}")
|
||||
_engine_cache[user_id] = engine
|
||||
logger.info(f"为用户 {user_id} 创建数据库引擎")
|
||||
@event.listens_for(engine.sync_engine, "connect")
|
||||
def set_sqlite_pragma(dbapi_conn, connection_record):
|
||||
cursor = dbapi_conn.cursor()
|
||||
cursor.execute("PRAGMA journal_mode=WAL")
|
||||
cursor.execute("PRAGMA synchronous=NORMAL")
|
||||
cursor.execute("PRAGMA cache_size=-64000") # 64MB 缓存
|
||||
cursor.execute("PRAGMA busy_timeout=30000") # 30秒超时
|
||||
cursor.close()
|
||||
|
||||
logger.info("✅ SQLite WAL 模式已启用(支持读写并发)")
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 启用 WAL 模式失败: {e},使用默认配置")
|
||||
|
||||
return _engine_cache[user_id]
|
||||
return _engine_cache[cache_key]
|
||||
|
||||
|
||||
async def get_db(request: Request):
|
||||
@@ -117,7 +167,7 @@ async def get_db(request: Request):
|
||||
_session_stats["created"] += 1
|
||||
_session_stats["active"] += 1
|
||||
|
||||
logger.debug(f"📊 会话创建 [User:{user_id}][ID:{session_id}] - 活跃:{_session_stats['active']}, 总创建:{_session_stats['created']}, 总关闭:{_session_stats['closed']}")
|
||||
# logger.debug(f"📊 会话创建 [User:{user_id}][ID:{session_id}] - 活跃:{_session_stats['active']}, 总创建:{_session_stats['created']}, 总关闭:{_session_stats['closed']}")
|
||||
|
||||
try:
|
||||
yield session
|
||||
@@ -157,8 +207,11 @@ async def get_db(request: Request):
|
||||
|
||||
logger.debug(f"📊 会话关闭 [User:{user_id}][ID:{session_id}] - 活跃:{_session_stats['active']}, 总创建:{_session_stats['created']}, 总关闭:{_session_stats['closed']}, 错误:{_session_stats['errors']}")
|
||||
|
||||
if _session_stats["active"] > 100:
|
||||
logger.warning(f"🚨 活跃会话数过多: {_session_stats['active']},可能存在连接泄漏!")
|
||||
# 使用优化后的会话监控阈值
|
||||
if _session_stats["active"] > settings.database_session_leak_threshold:
|
||||
logger.error(f"🚨 严重告警:活跃会话数 {_session_stats['active']} 超过泄漏阈值 {settings.database_session_leak_threshold}!")
|
||||
elif _session_stats["active"] > settings.database_session_max_active:
|
||||
logger.warning(f"⚠️ 警告:活跃会话数 {_session_stats['active']} 超过警告阈值 {settings.database_session_max_active},可能存在连接泄漏!")
|
||||
elif _session_stats["active"] < 0:
|
||||
logger.error(f"🚨 活跃会话数异常: {_session_stats['active']},统计可能不准确!")
|
||||
|
||||
@@ -170,147 +223,25 @@ async def get_db(request: Request):
|
||||
except:
|
||||
pass
|
||||
|
||||
async def _init_relationship_types(user_id: str):
|
||||
"""为指定用户初始化预置的关系类型数据
|
||||
async def init_db(user_id: str = None):
|
||||
"""
|
||||
初始化数据库(已弃用)
|
||||
|
||||
⚠️ 此函数已弃用,仅保留用于向后兼容
|
||||
|
||||
新的最佳实践:
|
||||
- 表结构管理: 使用 'alembic upgrade head'
|
||||
- 用户配置: Settings 在首次访问时自动创建(延迟初始化)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
user_id: 用户ID (已不再使用)
|
||||
"""
|
||||
from app.models.relationship import RelationshipType
|
||||
|
||||
relationship_types = [
|
||||
{"name": "父亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👨"},
|
||||
{"name": "母亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👩"},
|
||||
{"name": "兄弟", "category": "family", "reverse_name": "兄弟", "intimacy_range": "high", "icon": "👬"},
|
||||
{"name": "姐妹", "category": "family", "reverse_name": "姐妹", "intimacy_range": "high", "icon": "👭"},
|
||||
{"name": "子女", "category": "family", "reverse_name": "父母", "intimacy_range": "high", "icon": "👶"},
|
||||
{"name": "配偶", "category": "family", "reverse_name": "配偶", "intimacy_range": "high", "icon": "💑"},
|
||||
{"name": "恋人", "category": "family", "reverse_name": "恋人", "intimacy_range": "high", "icon": "💕"},
|
||||
|
||||
{"name": "师父", "category": "social", "reverse_name": "徒弟", "intimacy_range": "high", "icon": "🎓"},
|
||||
{"name": "徒弟", "category": "social", "reverse_name": "师父", "intimacy_range": "high", "icon": "📚"},
|
||||
{"name": "朋友", "category": "social", "reverse_name": "朋友", "intimacy_range": "medium", "icon": "🤝"},
|
||||
{"name": "同学", "category": "social", "reverse_name": "同学", "intimacy_range": "medium", "icon": "🎒"},
|
||||
{"name": "邻居", "category": "social", "reverse_name": "邻居", "intimacy_range": "low", "icon": "🏘️"},
|
||||
{"name": "知己", "category": "social", "reverse_name": "知己", "intimacy_range": "high", "icon": "💙"},
|
||||
|
||||
{"name": "上司", "category": "professional", "reverse_name": "下属", "intimacy_range": "low", "icon": "👔"},
|
||||
{"name": "下属", "category": "professional", "reverse_name": "上司", "intimacy_range": "low", "icon": "💼"},
|
||||
{"name": "同事", "category": "professional", "reverse_name": "同事", "intimacy_range": "medium", "icon": "🤵"},
|
||||
{"name": "合作伙伴", "category": "professional", "reverse_name": "合作伙伴", "intimacy_range": "medium", "icon": "🤜🤛"},
|
||||
|
||||
{"name": "敌人", "category": "hostile", "reverse_name": "敌人", "intimacy_range": "low", "icon": "⚔️"},
|
||||
{"name": "仇人", "category": "hostile", "reverse_name": "仇人", "intimacy_range": "low", "icon": "💢"},
|
||||
{"name": "竞争对手", "category": "hostile", "reverse_name": "竞争对手", "intimacy_range": "low", "icon": "🎯"},
|
||||
{"name": "宿敌", "category": "hostile", "reverse_name": "宿敌", "intimacy_range": "low", "icon": "⚡"},
|
||||
]
|
||||
|
||||
try:
|
||||
engine = await get_engine(user_id)
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False
|
||||
)
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
result = await session.execute(select(RelationshipType))
|
||||
existing = result.scalars().first()
|
||||
|
||||
if existing:
|
||||
logger.info(f"用户 {user_id} 的关系类型数据已存在,跳过初始化")
|
||||
return
|
||||
|
||||
logger.info(f"开始为用户 {user_id} 插入关系类型数据...")
|
||||
for rt_data in relationship_types:
|
||||
relationship_type = RelationshipType(**rt_data)
|
||||
session.add(relationship_type)
|
||||
|
||||
await session.commit()
|
||||
logger.info(f"成功为用户 {user_id} 插入 {len(relationship_types)} 条关系类型数据")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"用户 {user_id} 初始化关系类型数据失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
|
||||
async def _init_global_writing_styles(user_id: str):
|
||||
"""为指定用户初始化全局预设写作风格
|
||||
|
||||
全局预设风格的 project_id 为 NULL,所有用户共享
|
||||
只在第一次创建数据库时插入一次
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
"""
|
||||
from app.models.writing_style import WritingStyle
|
||||
from app.services.prompt_service import WritingStyleManager
|
||||
|
||||
try:
|
||||
engine = await get_engine(user_id)
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False
|
||||
)
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
# 检查是否已存在全局预设风格
|
||||
result = await session.execute(
|
||||
select(WritingStyle).where(WritingStyle.project_id.is_(None))
|
||||
)
|
||||
existing = result.scalars().first()
|
||||
|
||||
if existing:
|
||||
logger.info(f"用户 {user_id} 的全局预设风格已存在,跳过初始化")
|
||||
return
|
||||
|
||||
logger.info(f"开始为用户 {user_id} 插入全局预设写作风格...")
|
||||
|
||||
# 获取所有预设风格配置
|
||||
presets = WritingStyleManager.get_all_presets()
|
||||
|
||||
for index, (preset_id, preset_data) in enumerate(presets.items(), start=1):
|
||||
style = WritingStyle(
|
||||
project_id=None, # NULL 表示全局预设
|
||||
name=preset_data["name"],
|
||||
style_type="preset",
|
||||
preset_id=preset_id,
|
||||
description=preset_data["description"],
|
||||
prompt_content=preset_data["prompt_content"],
|
||||
order_index=index
|
||||
)
|
||||
session.add(style)
|
||||
|
||||
await session.commit()
|
||||
logger.info(f"成功为用户 {user_id} 插入 {len(presets)} 个全局预设写作风格")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"用户 {user_id} 初始化全局预设写作风格失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
async def init_db(user_id: str):
|
||||
"""初始化指定用户的数据库,创建所有表并插入预置数据
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始初始化用户 {user_id} 的数据库...")
|
||||
engine = await get_engine(user_id)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
await _init_relationship_types(user_id)
|
||||
await _init_global_writing_styles(user_id)
|
||||
|
||||
logger.info(f"用户 {user_id} 的数据库初始化成功")
|
||||
except Exception as e:
|
||||
logger.error(f"用户 {user_id} 的数据库初始化失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
logger.warning(
|
||||
"⚠️ init_db() 已弃用且无实际作用!\n"
|
||||
" - 表结构: 由 Alembic 管理\n"
|
||||
" - 用户配置: Settings API 自动创建\n"
|
||||
" 建议移除此调用"
|
||||
)
|
||||
|
||||
|
||||
async def close_db():
|
||||
@@ -324,4 +255,180 @@ async def close_db():
|
||||
logger.info("所有数据库连接已关闭")
|
||||
except Exception as e:
|
||||
logger.error(f"关闭数据库连接失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
raise
|
||||
|
||||
async def get_database_stats():
|
||||
"""获取数据库连接和会话统计信息
|
||||
|
||||
Returns:
|
||||
dict: 包含数据库统计信息的字典
|
||||
"""
|
||||
from app.config import settings
|
||||
|
||||
# 获取连接池详细状态
|
||||
pool_stats = {}
|
||||
cache_key = "shared_postgres"
|
||||
if cache_key in _engine_cache:
|
||||
engine = _engine_cache[cache_key]
|
||||
try:
|
||||
pool = engine.pool
|
||||
pool_stats = {
|
||||
"size": pool.size(), # 当前连接池大小
|
||||
"checked_in": pool.checkedin(), # 可用连接数
|
||||
"checked_out": pool.checkedout(), # 正在使用的连接数
|
||||
"overflow": pool.overflow(), # 溢出连接数
|
||||
"usage_percent": (pool.checkedout() / (settings.database_pool_size + settings.database_max_overflow)) * 100,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"获取连接池状态失败: {e}")
|
||||
pool_stats = {"error": str(e)}
|
||||
|
||||
stats = {
|
||||
"session_stats": {
|
||||
"created": _session_stats["created"],
|
||||
"closed": _session_stats["closed"],
|
||||
"active": _session_stats["active"],
|
||||
"errors": _session_stats["errors"],
|
||||
"generator_exits": _session_stats["generator_exits"],
|
||||
"last_check": _session_stats["last_check"],
|
||||
},
|
||||
"pool_stats": pool_stats, # 新增:连接池实时状态
|
||||
"engine_cache": {
|
||||
"total_engines": len(_engine_cache),
|
||||
"engine_keys": list(_engine_cache.keys()),
|
||||
},
|
||||
"config": {
|
||||
"database_type": "PostgreSQL",
|
||||
"pool_size": settings.database_pool_size,
|
||||
"max_overflow": settings.database_max_overflow,
|
||||
"total_connections": settings.database_pool_size + settings.database_max_overflow,
|
||||
"pool_timeout": settings.database_pool_timeout,
|
||||
"pool_recycle": settings.database_pool_recycle,
|
||||
"session_max_active_threshold": settings.database_session_max_active,
|
||||
"session_leak_threshold": settings.database_session_leak_threshold,
|
||||
},
|
||||
"health": {
|
||||
"status": "healthy",
|
||||
"warnings": [],
|
||||
"errors": [],
|
||||
}
|
||||
}
|
||||
|
||||
# 健康检查
|
||||
if _session_stats["active"] > settings.database_session_leak_threshold:
|
||||
stats["health"]["status"] = "critical"
|
||||
stats["health"]["errors"].append(
|
||||
f"活跃会话数 {_session_stats['active']} 超过泄漏阈值 {settings.database_session_leak_threshold}"
|
||||
)
|
||||
elif _session_stats["active"] > settings.database_session_max_active:
|
||||
stats["health"]["status"] = "warning"
|
||||
stats["health"]["warnings"].append(
|
||||
f"活跃会话数 {_session_stats['active']} 超过警告阈值 {settings.database_session_max_active}"
|
||||
)
|
||||
|
||||
if _session_stats["active"] < 0:
|
||||
stats["health"]["status"] = "error"
|
||||
stats["health"]["errors"].append(f"活跃会话数异常: {_session_stats['active']}")
|
||||
|
||||
# 连接池使用率检查
|
||||
if pool_stats and "usage_percent" in pool_stats:
|
||||
usage = pool_stats["usage_percent"]
|
||||
if usage > 90:
|
||||
stats["health"]["status"] = "warning"
|
||||
stats["health"]["warnings"].append(f"连接池使用率过高: {usage:.1f}%")
|
||||
elif usage > 95:
|
||||
stats["health"]["status"] = "critical"
|
||||
stats["health"]["errors"].append(f"连接池几乎耗尽: {usage:.1f}%")
|
||||
|
||||
error_rate = (_session_stats["errors"] / max(_session_stats["created"], 1)) * 100
|
||||
if error_rate > 5:
|
||||
if stats["health"]["status"] == "healthy":
|
||||
stats["health"]["status"] = "warning"
|
||||
stats["health"]["warnings"].append(f"会话错误率过高: {error_rate:.2f}%")
|
||||
|
||||
stats["health"]["error_rate"] = f"{error_rate:.2f}%"
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
async def check_database_health(user_id: str = None) -> dict:
|
||||
"""检查数据库连接健康状态
|
||||
|
||||
Args:
|
||||
user_id: 可选的用户ID,如果提供则检查特定用户的数据库
|
||||
|
||||
Returns:
|
||||
dict: 健康检查结果
|
||||
"""
|
||||
result = {
|
||||
"healthy": True,
|
||||
"checks": {},
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
try:
|
||||
# 检查引擎是否存在
|
||||
cache_key = "shared_postgres"
|
||||
if user_id:
|
||||
engine = await get_engine(user_id)
|
||||
else:
|
||||
if cache_key not in _engine_cache:
|
||||
result["checks"]["engine"] = {"status": "not_initialized", "healthy": True}
|
||||
return result
|
||||
engine = _engine_cache[cache_key]
|
||||
|
||||
# 测试数据库连接
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False
|
||||
)
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
# 执行简单查询测试连接
|
||||
await session.execute(text("SELECT 1"))
|
||||
result["checks"]["connection"] = {"status": "ok", "healthy": True}
|
||||
|
||||
# 检查连接池状态(仅PostgreSQL)
|
||||
if hasattr(engine.pool, 'size'):
|
||||
pool_status = {
|
||||
"size": engine.pool.size(),
|
||||
"checked_in": engine.pool.checkedin(),
|
||||
"checked_out": engine.pool.checkedout(),
|
||||
"overflow": engine.pool.overflow(),
|
||||
"healthy": True
|
||||
}
|
||||
|
||||
# 连接池健康检查
|
||||
if engine.pool.overflow() >= settings.database_max_overflow:
|
||||
pool_status["healthy"] = False
|
||||
pool_status["warning"] = "连接池溢出已满"
|
||||
result["healthy"] = False
|
||||
|
||||
result["checks"]["pool"] = pool_status
|
||||
|
||||
except Exception as e:
|
||||
result["healthy"] = False
|
||||
result["checks"]["error"] = {
|
||||
"status": "error",
|
||||
"message": str(e),
|
||||
"healthy": False
|
||||
}
|
||||
logger.error(f"数据库健康检查失败: {str(e)}", exc_info=True)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def reset_session_stats():
|
||||
"""重置会话统计信息(用于测试或维护)"""
|
||||
global _session_stats
|
||||
_session_stats = {
|
||||
"created": 0,
|
||||
"closed": 0,
|
||||
"active": 0,
|
||||
"errors": 0,
|
||||
"generator_exits": 0,
|
||||
"last_check": datetime.now().isoformat()
|
||||
}
|
||||
logger.info("✅ 会话统计信息已重置")
|
||||
return _session_stats
|
||||
@@ -130,11 +130,15 @@ def _configure_third_party_loggers():
|
||||
logging.getLogger('sqlalchemy.dialects').setLevel(logging.WARNING)
|
||||
logging.getLogger('sqlalchemy.orm').setLevel(logging.WARNING)
|
||||
|
||||
# aiosqlite - 异步SQLite,禁用DEBUG日志
|
||||
logging.getLogger('aiosqlite').setLevel(logging.WARNING)
|
||||
|
||||
# Watchfiles - 开发时的文件监控,降低级别
|
||||
logging.getLogger('watchfiles').setLevel(logging.WARNING)
|
||||
|
||||
# httpx - HTTP客户端
|
||||
# httpx/httpcore - HTTP客户端,禁用DEBUG日志
|
||||
logging.getLogger('httpx').setLevel(logging.WARNING)
|
||||
logging.getLogger('httpcore').setLevel(logging.WARNING)
|
||||
|
||||
# openai/anthropic - AI客户端库
|
||||
logging.getLogger('openai').setLevel(logging.WARNING)
|
||||
|
||||
+24
-2
@@ -12,6 +12,7 @@ from app.database import close_db, _session_stats
|
||||
from app.logger import setup_logging, get_logger
|
||||
from app.middleware import RequestIDMiddleware
|
||||
from app.middleware.auth_middleware import AuthMiddleware
|
||||
from app.mcp import mcp_client, register_status_sync
|
||||
|
||||
setup_logging(
|
||||
level=config_settings.log_level,
|
||||
@@ -26,10 +27,23 @@ logger = get_logger(__name__)
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
logger.info("应用启动,等待用户登录...")
|
||||
# 注册MCP状态同步服务
|
||||
register_status_sync()
|
||||
|
||||
logger.info("应用启动完成")
|
||||
|
||||
yield
|
||||
|
||||
# 清理MCP插件
|
||||
await mcp_client.cleanup()
|
||||
|
||||
# 清理HTTP客户端池
|
||||
from app.services.ai_service import cleanup_http_clients
|
||||
await cleanup_http_clients()
|
||||
|
||||
# 关闭数据库连接
|
||||
await close_db()
|
||||
|
||||
logger.info("应用已关闭")
|
||||
|
||||
|
||||
@@ -114,22 +128,30 @@ async def db_session_stats():
|
||||
from app.api import (
|
||||
projects, outlines, characters, chapters,
|
||||
wizard_stream, relationships, organizations,
|
||||
auth, users, settings, writing_styles, memories
|
||||
auth, users, settings, writing_styles, memories,
|
||||
mcp_plugins, admin, inspiration, prompt_templates,
|
||||
changelog, careers
|
||||
)
|
||||
|
||||
app.include_router(auth.router, prefix="/api")
|
||||
app.include_router(users.router, prefix="/api")
|
||||
app.include_router(settings.router, prefix="/api")
|
||||
app.include_router(admin.router, prefix="/api")
|
||||
|
||||
app.include_router(projects.router, prefix="/api")
|
||||
app.include_router(wizard_stream.router, prefix="/api")
|
||||
app.include_router(inspiration.router, prefix="/api")
|
||||
app.include_router(outlines.router, prefix="/api")
|
||||
app.include_router(characters.router, prefix="/api")
|
||||
app.include_router(careers.router, prefix="/api") # 职业管理API
|
||||
app.include_router(chapters.router, prefix="/api")
|
||||
app.include_router(relationships.router, prefix="/api")
|
||||
app.include_router(organizations.router, prefix="/api")
|
||||
app.include_router(writing_styles.router, prefix="/api")
|
||||
app.include_router(memories.router) # 记忆管理API (已包含/api前缀)
|
||||
app.include_router(mcp_plugins.router, prefix="/api") # MCP插件管理API
|
||||
app.include_router(prompt_templates.router, prefix="/api") # 提示词模板管理API
|
||||
app.include_router(changelog.router, prefix="/api") # 更新日志API
|
||||
|
||||
static_dir = Path(__file__).parent.parent / "static"
|
||||
if static_dir.exists():
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
"""MCP模块 - 统一的MCP客户端管理
|
||||
|
||||
本模块提供MCP(Model Context Protocol)客户端的统一管理接口。
|
||||
|
||||
推荐使用方式:
|
||||
from app.mcp import mcp_client, MCPPluginConfig
|
||||
|
||||
# 注册插件
|
||||
await mcp_client.register(MCPPluginConfig(
|
||||
user_id="user123",
|
||||
plugin_name="exa-search",
|
||||
url="http://localhost:8000/mcp"
|
||||
))
|
||||
|
||||
# 获取工具
|
||||
tools = await mcp_client.get_tools("user123", "exa-search")
|
||||
|
||||
# 调用工具
|
||||
result = await mcp_client.call_tool("user123", "exa-search", "web_search", {"query": "..."})
|
||||
|
||||
# 注册状态变更回调
|
||||
from app.mcp.status_sync import register_status_sync
|
||||
register_status_sync()
|
||||
"""
|
||||
|
||||
from .facade import mcp_client, MCPClientFacade, MCPPluginConfig, MCPError, PluginStatus
|
||||
from .status_sync import register_status_sync
|
||||
|
||||
__all__ = [
|
||||
"mcp_client",
|
||||
"MCPClientFacade",
|
||||
"MCPPluginConfig",
|
||||
"MCPError",
|
||||
"PluginStatus",
|
||||
"register_status_sync",
|
||||
]
|
||||
@@ -0,0 +1,42 @@
|
||||
"""MCP模块配置常量"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MCPConfig:
|
||||
"""MCP模块配置常量(不可变)"""
|
||||
|
||||
# 连接池配置
|
||||
MAX_CLIENTS: int = 1000 # 最大客户端数量
|
||||
CLIENT_TTL_SECONDS: int = 3600 # 客户端过期时间(1小时)
|
||||
IDLE_TIMEOUT_SECONDS: int = 1800 # 空闲超时(30分钟)
|
||||
|
||||
# 健康检查配置
|
||||
HEALTH_CHECK_INTERVAL_SECONDS: int = 60 # 健康检查间隔
|
||||
ERROR_RATE_CRITICAL: float = 0.7 # 严重错误率阈值
|
||||
ERROR_RATE_WARNING: float = 0.4 # 警告错误率阈值
|
||||
MIN_REQUESTS_FOR_HEALTH_CHECK: int = 10 # 进行健康检查的最小请求数
|
||||
|
||||
# 清理任务配置
|
||||
CLEANUP_INTERVAL_SECONDS: int = 300 # 清理任务间隔(5分钟)
|
||||
|
||||
# 缓存配置
|
||||
TOOL_CACHE_TTL_MINUTES: int = 10 # 工具定义缓存TTL
|
||||
|
||||
# 重试配置
|
||||
MAX_RETRIES: int = 3 # 最大重试次数
|
||||
BASE_RETRY_DELAY_SECONDS: float = 1.0 # 基础重试延迟
|
||||
MAX_RETRY_DELAY_SECONDS: float = 10.0 # 最大重试延迟
|
||||
|
||||
# 超时配置
|
||||
DEFAULT_TIMEOUT_SECONDS: float = 60.0 # 默认超时时间
|
||||
TOOL_CALL_TIMEOUT_SECONDS: float = 60.0 # 工具调用超时时间
|
||||
|
||||
# 日志配置
|
||||
LOG_TOOL_ARGUMENTS: bool = True # 是否记录工具参数
|
||||
LOG_TOOL_RESULTS: bool = False # 是否记录工具结果(可能很大)
|
||||
|
||||
|
||||
# 全局配置实例
|
||||
mcp_config = MCPConfig()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,50 @@
|
||||
"""MCP插件状态同步服务
|
||||
|
||||
将内存中的会话状态变更同步到数据库,确保状态一致性。
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from app.models.mcp_plugin import MCPPlugin
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def sync_status_to_db(event: Dict[str, Any]):
|
||||
"""
|
||||
状态变更回调 - 同步到数据库
|
||||
"""
|
||||
user_id = event["user_id"]
|
||||
plugin_name = event["plugin_name"]
|
||||
new_status = event["new_status"]
|
||||
reason = event.get("reason", "")
|
||||
|
||||
try:
|
||||
from app.database import get_engine
|
||||
|
||||
engine = await get_engine(user_id)
|
||||
AsyncSessionLocal = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
stmt = (
|
||||
update(MCPPlugin)
|
||||
.where(MCPPlugin.user_id == user_id, MCPPlugin.plugin_name == plugin_name)
|
||||
.values(status=new_status, last_error=reason if new_status == "error" else None)
|
||||
)
|
||||
await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
logger.debug(f"✅ 状态已同步到数据库: {plugin_name} -> {new_status}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 状态同步失败: {plugin_name}, 错误: {e}")
|
||||
|
||||
|
||||
def register_status_sync():
|
||||
"""注册状态同步回调到MCP客户端"""
|
||||
from app.mcp import mcp_client
|
||||
mcp_client.register_status_callback(sync_status_to_db)
|
||||
logger.info("✅ MCP状态同步服务已注册")
|
||||
@@ -1,9 +1,12 @@
|
||||
"""
|
||||
认证中间件 - 从 Cookie 中提取用户信息并注入到 request.state
|
||||
"""
|
||||
from fastapi import Request
|
||||
from fastapi import Request, HTTPException
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from app.user_manager import user_manager
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AuthMiddleware(BaseHTTPMiddleware):
|
||||
@@ -20,9 +23,18 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
if user_id:
|
||||
user = await user_manager.get_user(user_id)
|
||||
if user:
|
||||
request.state.user_id = user_id
|
||||
request.state.user = user
|
||||
request.state.is_admin = user.is_admin
|
||||
# 检查用户是否被禁用 (trust_level = -1)
|
||||
if user.trust_level == -1:
|
||||
logger.warning(f"禁用用户尝试访问: {user_id} ({user.username})")
|
||||
# 清除用户状态,视为未登录
|
||||
request.state.user_id = None
|
||||
request.state.user = None
|
||||
request.state.is_admin = False
|
||||
else:
|
||||
# 用户正常,注入状态
|
||||
request.state.user_id = user_id
|
||||
request.state.user = user
|
||||
request.state.is_admin = user.is_admin
|
||||
else:
|
||||
# 用户不存在,清除状态
|
||||
request.state.user_id = None
|
||||
|
||||
@@ -1,35 +1,44 @@
|
||||
"""数据库模型"""
|
||||
"""数据模型导出"""
|
||||
from app.models.project import Project
|
||||
from app.models.outline import Outline
|
||||
from app.models.character import Character
|
||||
from app.models.chapter import Chapter
|
||||
from app.models.character import Character
|
||||
from app.models.relationship import CharacterRelationship, Organization, OrganizationMember, RelationshipType
|
||||
from app.models.generation_history import GenerationHistory
|
||||
from app.models.analysis_task import AnalysisTask
|
||||
from app.models.batch_generation_task import BatchGenerationTask
|
||||
from app.models.settings import Settings
|
||||
from app.models.memory import StoryMemory, PlotAnalysis
|
||||
from app.models.writing_style import WritingStyle
|
||||
from app.models.project_default_style import ProjectDefaultStyle
|
||||
from app.models.relationship import (
|
||||
RelationshipType,
|
||||
CharacterRelationship,
|
||||
Organization,
|
||||
OrganizationMember
|
||||
)
|
||||
from app.models.memory import StoryMemory, PlotAnalysis
|
||||
from app.models.analysis_task import AnalysisTask
|
||||
from app.models.mcp_plugin import MCPPlugin
|
||||
from app.models.user import User, UserPassword
|
||||
from app.models.regeneration_task import RegenerationTask
|
||||
from app.models.career import Career, CharacterCareer
|
||||
from app.models.prompt_template import PromptTemplate
|
||||
|
||||
__all__ = [
|
||||
"Project",
|
||||
"Outline",
|
||||
"Character",
|
||||
"Chapter",
|
||||
"GenerationHistory",
|
||||
"Settings",
|
||||
"WritingStyle",
|
||||
"ProjectDefaultStyle",
|
||||
"RelationshipType",
|
||||
"Character",
|
||||
"CharacterRelationship",
|
||||
"Organization",
|
||||
"OrganizationMember",
|
||||
"RelationshipType",
|
||||
"GenerationHistory",
|
||||
"AnalysisTask",
|
||||
"BatchGenerationTask",
|
||||
"Settings",
|
||||
"StoryMemory",
|
||||
"PlotAnalysis",
|
||||
"AnalysisTask",
|
||||
"WritingStyle",
|
||||
"ProjectDefaultStyle",
|
||||
"MCPPlugin",
|
||||
"User",
|
||||
"UserPassword",
|
||||
"RegenerationTask",
|
||||
"Career",
|
||||
"CharacterCareer",
|
||||
"PromptTemplate"
|
||||
]
|
||||
@@ -0,0 +1,43 @@
|
||||
"""批量生成任务数据模型"""
|
||||
from sqlalchemy import Column, String, Integer, DateTime, Boolean, JSON
|
||||
from sqlalchemy.sql import func
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class BatchGenerationTask(Base):
|
||||
"""批量生成任务表"""
|
||||
__tablename__ = "batch_generation_tasks"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
project_id = Column(String(36), nullable=False, comment="项目ID")
|
||||
user_id = Column(String(100), nullable=False, comment="用户ID")
|
||||
|
||||
# 任务配置
|
||||
start_chapter_number = Column(Integer, nullable=False, comment="起始章节序号")
|
||||
chapter_count = Column(Integer, nullable=False, comment="生成章节数量")
|
||||
chapter_ids = Column(JSON, nullable=False, comment="待生成的章节ID列表")
|
||||
style_id = Column(Integer, comment="使用的写作风格ID")
|
||||
target_word_count = Column(Integer, default=3000, comment="目标字数")
|
||||
enable_analysis = Column(Boolean, default=False, comment="是否启用同步分析")
|
||||
|
||||
# 任务状态
|
||||
status = Column(String(20), default="pending", comment="任务状态: pending/running/completed/failed/cancelled")
|
||||
total_chapters = Column(Integer, default=0, comment="总章节数")
|
||||
completed_chapters = Column(Integer, default=0, comment="已完成章节数")
|
||||
failed_chapters = Column(JSON, default=list, comment="失败的章节信息列表")
|
||||
current_chapter_id = Column(String(36), comment="当前正在生成的章节ID")
|
||||
current_chapter_number = Column(Integer, comment="当前正在生成的章节序号")
|
||||
current_retry_count = Column(Integer, default=0, comment="当前章节重试次数")
|
||||
max_retries = Column(Integer, default=3, comment="最大重试次数")
|
||||
|
||||
# 时间记录
|
||||
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
||||
started_at = Column(DateTime, comment="开始时间")
|
||||
completed_at = Column(DateTime, comment="完成时间")
|
||||
|
||||
# 错误信息
|
||||
error_message = Column(String(500), comment="错误信息")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<BatchGenerationTask(id={self.id}, status={self.status}, completed={self.completed_chapters}/{self.total_chapters})>"
|
||||
@@ -0,0 +1,77 @@
|
||||
"""职业数据模型"""
|
||||
from sqlalchemy import Column, String, Text, DateTime, Integer, ForeignKey, Index
|
||||
from sqlalchemy.sql import func
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class Career(Base):
|
||||
"""职业表"""
|
||||
__tablename__ = "careers"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)
|
||||
|
||||
# 基本信息
|
||||
name = Column(String(100), nullable=False, comment="职业名称")
|
||||
type = Column(String(20), nullable=False, comment="职业类型: main(主职业)/sub(副职业)")
|
||||
description = Column(Text, comment="职业描述")
|
||||
category = Column(String(50), comment="职业分类(如:战斗系、生产系、辅助系)")
|
||||
|
||||
# 阶段设定
|
||||
stages = Column(Text, nullable=False, comment="职业阶段列表(JSON): [{level:1, name:'', description:''}, ...]")
|
||||
max_stage = Column(Integer, nullable=False, default=10, comment="最大阶段数")
|
||||
|
||||
# 职业特性
|
||||
requirements = Column(Text, comment="职业要求/限制")
|
||||
special_abilities = Column(Text, comment="特殊能力描述")
|
||||
worldview_rules = Column(Text, comment="世界观规则关联")
|
||||
|
||||
# 职业属性加成(可选,JSON格式)
|
||||
attribute_bonuses = Column(Text, comment="属性加成(JSON): {strength: '+10%', intelligence: '+5%'}")
|
||||
|
||||
# 元数据
|
||||
source = Column(String(20), default='ai', comment="来源: ai/manual")
|
||||
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_project_id', 'project_id'),
|
||||
Index('idx_type', 'type'),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Career(id={self.id}, name={self.name}, type={self.type})>"
|
||||
|
||||
|
||||
class CharacterCareer(Base):
|
||||
"""角色职业关联表"""
|
||||
__tablename__ = "character_careers"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
character_id = Column(String(36), ForeignKey("characters.id", ondelete="CASCADE"), nullable=False)
|
||||
career_id = Column(String(36), ForeignKey("careers.id", ondelete="CASCADE"), nullable=False)
|
||||
career_type = Column(String(20), nullable=False, comment="main(主职业)/sub(副职业)")
|
||||
|
||||
# 阶段进度
|
||||
current_stage = Column(Integer, nullable=False, default=1, comment="当前阶段(对应职业中的数值)")
|
||||
stage_progress = Column(Integer, default=0, comment="阶段内进度(0-100)")
|
||||
|
||||
# 时间记录
|
||||
started_at = Column(String(100), comment="开始修炼时间(小说时间线)")
|
||||
reached_current_stage_at = Column(String(100), comment="到达当前阶段时间")
|
||||
|
||||
# 备注
|
||||
notes = Column(Text, comment="备注(如:修炼心得、特殊事件)")
|
||||
|
||||
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_character_id', 'character_id'),
|
||||
Index('idx_career_type', 'career_type'),
|
||||
Index('idx_character_career', 'character_id', 'career_id', unique=True),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<CharacterCareer(character_id={self.character_id}, career_id={self.career_id}, type={self.career_type})>"
|
||||
@@ -17,8 +17,16 @@ class Chapter(Base):
|
||||
summary = Column(Text, comment="章节摘要")
|
||||
word_count = Column(Integer, default=0, comment="字数统计")
|
||||
status = Column(String(20), default="draft", comment="章节状态")
|
||||
|
||||
# 大纲关联字段(实现一对多关系)
|
||||
outline_id = Column(String(36), ForeignKey("outlines.id", ondelete="SET NULL"), nullable=True, comment="关联的大纲ID")
|
||||
sub_index = Column(Integer, default=1, comment="大纲下的子章节序号")
|
||||
|
||||
# 大纲展开规划数据(JSON格式)
|
||||
expansion_plan = Column(Text, comment="展开规划详情(JSON): 包含key_events, character_focus, emotional_tone等")
|
||||
|
||||
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Chapter(id={self.id}, chapter_number={self.chapter_number}, title={self.title})>"
|
||||
return f"<Chapter(id={self.id}, chapter_number={self.chapter_number}, title={self.title}, outline_id={self.outline_id})>"
|
||||
@@ -1,5 +1,5 @@
|
||||
"""角色数据模型"""
|
||||
from sqlalchemy import Column, String, Text, DateTime, ForeignKey, Boolean
|
||||
from sqlalchemy import Column, String, Text, DateTime, ForeignKey, Boolean, Integer
|
||||
from sqlalchemy.sql import func
|
||||
from app.database import Base
|
||||
import uuid
|
||||
@@ -14,8 +14,8 @@ class Character(Base):
|
||||
|
||||
# 基本信息
|
||||
name = Column(String(100), nullable=False, comment="角色/组织名称")
|
||||
age = Column(String(20), comment="年龄")
|
||||
gender = Column(String(20), comment="性别")
|
||||
age = Column(String(50), comment="年龄")
|
||||
gender = Column(String(50), comment="性别")
|
||||
is_organization = Column(Boolean, default=False, comment="是否为组织")
|
||||
|
||||
# 角色类型:protagonist(主角)/supporting(配角)/antagonist(反派)
|
||||
@@ -32,6 +32,11 @@ class Character(Base):
|
||||
organization_purpose = Column(String(500), comment="组织目的")
|
||||
organization_members = Column(Text, comment="组织成员(JSON)")
|
||||
|
||||
# 职业相关字段(冗余字段,用于提升查询性能)
|
||||
main_career_id = Column(String(36), ForeignKey("careers.id", ondelete="SET NULL"), comment="主职业ID")
|
||||
main_career_stage = Column(Integer, comment="主职业当前阶段")
|
||||
sub_careers = Column(Text, comment="副职业列表(JSON): [{\"career_id\": \"xxx\", \"stage\": 3}, ...]")
|
||||
|
||||
# 其他
|
||||
avatar_url = Column(String(500), comment="头像URL")
|
||||
traits = Column(Text, comment="特征标签(JSON)")
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
"""MCP插件配置数据模型"""
|
||||
from sqlalchemy import Column, String, Text, Boolean, Integer, DateTime, Index, JSON
|
||||
from sqlalchemy.sql import func
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class MCPPlugin(Base):
|
||||
"""MCP插件配置表"""
|
||||
__tablename__ = "mcp_plugins"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
user_id = Column(String(50), nullable=False, index=True, comment="用户ID")
|
||||
|
||||
# 插件基本信息
|
||||
plugin_name = Column(String(100), nullable=False, comment="插件名称(唯一标识)")
|
||||
display_name = Column(String(200), nullable=False, comment="显示名称")
|
||||
description = Column(Text, comment="插件描述")
|
||||
plugin_type = Column(String(50), default="http", comment="插件类型:http/stdio")
|
||||
|
||||
# 连接配置
|
||||
server_url = Column(String(500), comment="服务器URL(HTTP类型)")
|
||||
command = Column(String(500), comment="启动命令(stdio类型)")
|
||||
args = Column(JSON, comment="命令参数(stdio类型)")
|
||||
env = Column(JSON, comment="环境变量")
|
||||
headers = Column(JSON, comment="HTTP请求头")
|
||||
|
||||
# 插件配置
|
||||
config = Column(JSON, comment="插件特定配置(JSON)")
|
||||
tools = Column(JSON, comment="提供的工具列表")
|
||||
|
||||
# 状态管理
|
||||
enabled = Column(Boolean, default=True, comment="是否启用")
|
||||
status = Column(String(50), default="inactive", comment="状态:active/inactive/error")
|
||||
last_error = Column(Text, comment="最后错误信息")
|
||||
last_test_at = Column(DateTime, comment="最后测试时间")
|
||||
|
||||
# 排序和分组
|
||||
category = Column(String(100), default="general", comment="分类")
|
||||
sort_order = Column(Integer, default=0, comment="排序顺序")
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_user_plugin', 'user_id', 'plugin_name', unique=True),
|
||||
Index('idx_user_enabled', 'user_id', 'enabled'),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<MCPPlugin(id={self.id}, name={self.plugin_name}, enabled={self.enabled})>"
|
||||
@@ -9,7 +9,7 @@ class StoryMemory(Base):
|
||||
"""故事记忆表 - 存储结构化的故事片段和元数据"""
|
||||
__tablename__ = "story_memories"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
id = Column(String(100), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
chapter_id = Column(String(36), ForeignKey("chapters.id", ondelete="CASCADE"), nullable=True, index=True)
|
||||
|
||||
@@ -45,7 +45,7 @@ class StoryMemory(Base):
|
||||
|
||||
# 伏笔相关字段
|
||||
is_foreshadow = Column(Integer, default=0, comment="伏笔状态: 0=普通记忆, 1=已埋下伏笔, 2=伏笔已回收")
|
||||
foreshadow_resolved_at = Column(String(36), ForeignKey("chapters.id", ondelete="SET NULL"), comment="伏笔回收的章节ID")
|
||||
foreshadow_resolved_at = Column(String(100), ForeignKey("chapters.id", ondelete="SET NULL"), comment="伏笔回收的章节ID")
|
||||
foreshadow_strength = Column(Float, comment="伏笔强度 0.0-1.0")
|
||||
|
||||
# 向量数据库关联
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""项目数据模型"""
|
||||
from sqlalchemy import Column, String, Text, DateTime, Integer
|
||||
from sqlalchemy import Column, String, Text, DateTime, Integer, CheckConstraint
|
||||
from sqlalchemy.sql import func
|
||||
from app.database import Base
|
||||
import uuid
|
||||
@@ -10,6 +10,7 @@ class Project(Base):
|
||||
__tablename__ = "projects"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
user_id = Column(String(100), nullable=False, index=True, comment="用户ID")
|
||||
title = Column(String(200), nullable=False, comment="项目标题")
|
||||
description = Column(Text, comment="项目简介")
|
||||
theme = Column(Text, comment="主题")
|
||||
@@ -19,6 +20,7 @@ class Project(Base):
|
||||
status = Column(String(20), default="planning", comment="创作状态")
|
||||
wizard_status = Column(String(20), default="incomplete", comment="向导完成状态: incomplete/completed")
|
||||
wizard_step = Column(Integer, default=0, comment="向导当前步骤: 0-4")
|
||||
outline_mode = Column(String(20), nullable=False, default="one-to-many", comment="大纲章节模式: one-to-one(传统模式) 或 one-to-many(细化模式)")
|
||||
|
||||
# 世界构建字段
|
||||
world_time_period = Column(Text, comment="时间背景")
|
||||
@@ -34,5 +36,12 @@ class Project(Base):
|
||||
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
|
||||
__table_args__ = (
|
||||
CheckConstraint(
|
||||
"outline_mode IN ('one-to-one', 'one-to-many')",
|
||||
name='check_outline_mode'
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Project(id={self.id}, title={self.title})>"
|
||||
@@ -0,0 +1,30 @@
|
||||
"""提示词模板数据模型"""
|
||||
from sqlalchemy import Column, String, Text, Boolean, DateTime, Index
|
||||
from sqlalchemy.sql import func
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class PromptTemplate(Base):
|
||||
"""提示词模板表"""
|
||||
__tablename__ = "prompt_templates"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
user_id = Column(String(50), nullable=False, index=True, comment="用户ID")
|
||||
template_key = Column(String(100), nullable=False, comment="模板键名")
|
||||
template_name = Column(String(200), nullable=False, comment="模板显示名称")
|
||||
template_content = Column(Text, nullable=False, comment="模板内容")
|
||||
description = Column(Text, comment="模板描述")
|
||||
category = Column(String(50), comment="模板分类")
|
||||
parameters = Column(Text, comment="模板参数定义(JSON)")
|
||||
is_active = Column(Boolean, default=True, comment="是否启用")
|
||||
is_system_default = Column(Boolean, default=False, comment="是否为系统默认模板")
|
||||
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_user_template', 'user_id', 'template_key', unique=True),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<PromptTemplate(id={self.id}, user_id={self.user_id}, template_key={self.template_key})>"
|
||||
@@ -0,0 +1,51 @@
|
||||
"""章节重新生成任务模型"""
|
||||
from sqlalchemy import Column, String, Text, Integer, DateTime, ForeignKey, JSON, Boolean
|
||||
from sqlalchemy.sql import func
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class RegenerationTask(Base):
|
||||
"""章节重新生成任务表"""
|
||||
__tablename__ = "regeneration_tasks"
|
||||
|
||||
# 基本信息
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
chapter_id = Column(String(36), ForeignKey('chapters.id', ondelete='CASCADE'), nullable=False, index=True)
|
||||
analysis_id = Column(String(36), nullable=True, comment="关联的分析结果ID")
|
||||
user_id = Column(String(50), nullable=False, index=True)
|
||||
project_id = Column(String(36), nullable=False, index=True)
|
||||
|
||||
# 修改指令
|
||||
modification_instructions = Column(Text, nullable=False, comment="综合修改指令")
|
||||
original_suggestions = Column(JSON, comment="来自分析的原始建议列表")
|
||||
selected_suggestion_indices = Column(JSON, comment="用户选择的建议索引")
|
||||
custom_instructions = Column(Text, comment="用户自定义修改意见")
|
||||
|
||||
# 生成参数
|
||||
style_id = Column(Integer, nullable=True, comment="写作风格ID")
|
||||
target_word_count = Column(Integer, default=3000, comment="目标字数")
|
||||
focus_areas = Column(JSON, comment="重点优化方向")
|
||||
preserve_elements = Column(JSON, comment="需要保留的元素配置")
|
||||
|
||||
# 状态跟踪
|
||||
status = Column(String(20), default='pending', comment="pending/running/completed/failed")
|
||||
progress = Column(Integer, default=0, comment="进度 0-100")
|
||||
error_message = Column(Text, nullable=True)
|
||||
|
||||
# 内容版本
|
||||
original_content = Column(Text, comment="原始章节内容快照")
|
||||
original_word_count = Column(Integer, comment="原始字数")
|
||||
regenerated_content = Column(Text, comment="重新生成的内容")
|
||||
regenerated_word_count = Column(Integer, comment="新内容字数")
|
||||
version_number = Column(Integer, default=1, comment="版本号")
|
||||
version_note = Column(String(500), comment="版本说明")
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(DateTime, server_default=func.now())
|
||||
started_at = Column(DateTime, nullable=True)
|
||||
completed_at = Column(DateTime, nullable=True)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<RegenerationTask(id={self.id[:8]}..., chapter_id={self.chapter_id[:8]}..., status={self.status})>"
|
||||
|
||||
@@ -38,7 +38,7 @@ class CharacterRelationship(Base):
|
||||
relationship_name = Column(String(100), comment="自定义关系名称")
|
||||
|
||||
# 关系属性
|
||||
intimacy_level = Column(Integer, default=50, comment="亲密度:0-100")
|
||||
intimacy_level = Column(Integer, default=50, comment="亲密度:-100到100")
|
||||
status = Column(String(20), default="active", comment="状态:active/broken/past/complicated")
|
||||
description = Column(Text, comment="关系详细描述")
|
||||
|
||||
@@ -75,7 +75,7 @@ class Organization(Base):
|
||||
|
||||
# 组织特色
|
||||
motto = Column(String(200), comment="宗旨/口号")
|
||||
color = Column(String(20), comment="代表颜色")
|
||||
color = Column(String(100), comment="代表颜色")
|
||||
|
||||
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
|
||||
@@ -17,6 +17,7 @@ class Settings(Base):
|
||||
llm_model = Column(String(100), default="gpt-4", comment="模型名称")
|
||||
temperature = Column(Float, default=0.7, comment="温度参数")
|
||||
max_tokens = Column(Integer, default=2000, comment="最大token数")
|
||||
system_prompt = Column(Text, comment="系统级别提示词,每次AI调用都会使用")
|
||||
preferences = Column(Text, comment="其他偏好设置(JSON)")
|
||||
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
"""
|
||||
用户数据模型 - 存储用户基本信息
|
||||
"""
|
||||
from sqlalchemy import Column, String, Integer, Boolean, DateTime
|
||||
from sqlalchemy.sql import func
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class User(Base):
|
||||
"""用户模型 - 存储OAuth和本地用户信息"""
|
||||
__tablename__ = "users"
|
||||
|
||||
user_id = Column(String(100), primary_key=True, index=True, comment="用户ID,格式:linuxdo_{id} 或 local_{id}")
|
||||
username = Column(String(100), nullable=False, index=True, comment="用户名")
|
||||
display_name = Column(String(200), nullable=False, comment="显示名称")
|
||||
avatar_url = Column(String(500), nullable=True, comment="头像URL")
|
||||
trust_level = Column(Integer, default=0, comment="信任等级(仅用于显示)")
|
||||
is_admin = Column(Boolean, default=False, comment="是否为管理员")
|
||||
linuxdo_id = Column(String(100), nullable=False, unique=True, index=True, comment="LinuxDO用户ID或本地用户ID")
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), comment="创建时间")
|
||||
last_login = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), comment="最后登录时间")
|
||||
|
||||
def to_dict(self):
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"user_id": self.user_id,
|
||||
"username": self.username,
|
||||
"display_name": self.display_name,
|
||||
"avatar_url": self.avatar_url,
|
||||
"trust_level": self.trust_level,
|
||||
"is_admin": self.is_admin,
|
||||
"linuxdo_id": self.linuxdo_id,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"last_login": self.last_login.isoformat() if self.last_login else None,
|
||||
}
|
||||
|
||||
|
||||
class UserPassword(Base):
|
||||
"""用户密码模型 - 存储用户密码信息"""
|
||||
__tablename__ = "user_passwords"
|
||||
|
||||
user_id = Column(String(100), primary_key=True, index=True, comment="用户ID")
|
||||
username = Column(String(100), nullable=False, comment="用户名")
|
||||
password_hash = Column(String(64), nullable=False, comment="密码哈希(SHA256)")
|
||||
has_custom_password = Column(Boolean, default=False, comment="是否为自定义密码")
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), comment="创建时间")
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
@@ -9,7 +9,7 @@ class WritingStyle(Base):
|
||||
__tablename__ = "writing_styles"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=True, comment="所属项目ID(NULL表示全局预设风格)")
|
||||
user_id = Column(String(255), ForeignKey("users.user_id", ondelete="CASCADE"), nullable=True, comment="所属用户ID(NULL表示全局预设风格)")
|
||||
name = Column(String(100), nullable=False, comment="风格名称")
|
||||
style_type = Column(String(50), nullable=False, comment="风格类型:preset/custom")
|
||||
preset_id = Column(String(50), comment="预设风格ID:natural/classical/modern等")
|
||||
@@ -20,4 +20,4 @@ class WritingStyle(Base):
|
||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<WritingStyle(id={self.id}, name={self.name}, project_id={self.project_id})>"
|
||||
return f"<WritingStyle(id={self.id}, name={self.name}, user_id={self.user_id})>"
|
||||
@@ -0,0 +1,154 @@
|
||||
"""职业相关的Pydantic模型"""
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class CareerStage(BaseModel):
|
||||
"""职业阶段模型"""
|
||||
level: int = Field(..., description="阶段等级")
|
||||
name: str = Field(..., description="阶段名称")
|
||||
description: Optional[str] = Field(None, description="阶段描述")
|
||||
|
||||
|
||||
class CareerBase(BaseModel):
|
||||
"""职业基础模型"""
|
||||
name: str = Field(..., description="职业名称")
|
||||
type: str = Field(..., description="职业类型: main(主职业)/sub(副职业)")
|
||||
description: Optional[str] = Field(None, description="职业描述")
|
||||
category: Optional[str] = Field(None, description="职业分类")
|
||||
stages: List[CareerStage] = Field(..., description="职业阶段列表")
|
||||
max_stage: int = Field(10, description="最大阶段数")
|
||||
requirements: Optional[str] = Field(None, description="职业要求/限制")
|
||||
special_abilities: Optional[str] = Field(None, description="特殊能力描述")
|
||||
worldview_rules: Optional[str] = Field(None, description="世界观规则关联")
|
||||
attribute_bonuses: Optional[Dict[str, str]] = Field(None, description="属性加成")
|
||||
|
||||
|
||||
class CareerCreate(CareerBase):
|
||||
"""创建职业的请求模型"""
|
||||
project_id: str = Field(..., description="项目ID")
|
||||
source: str = Field("manual", description="来源: ai/manual")
|
||||
|
||||
|
||||
class CareerUpdate(BaseModel):
|
||||
"""更新职业的请求模型"""
|
||||
name: Optional[str] = None
|
||||
type: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
category: Optional[str] = None
|
||||
stages: Optional[List[CareerStage]] = None
|
||||
max_stage: Optional[int] = None
|
||||
requirements: Optional[str] = None
|
||||
special_abilities: Optional[str] = None
|
||||
worldview_rules: Optional[str] = None
|
||||
attribute_bonuses: Optional[Dict[str, str]] = None
|
||||
|
||||
|
||||
class CareerResponse(BaseModel):
|
||||
"""职业响应模型"""
|
||||
id: str
|
||||
project_id: str
|
||||
name: str
|
||||
type: str
|
||||
description: Optional[str] = None
|
||||
category: Optional[str] = None
|
||||
stages: List[CareerStage]
|
||||
max_stage: int
|
||||
requirements: Optional[str] = None
|
||||
special_abilities: Optional[str] = None
|
||||
worldview_rules: Optional[str] = None
|
||||
attribute_bonuses: Optional[Dict[str, str]] = None
|
||||
source: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class CareerListResponse(BaseModel):
|
||||
"""职业列表响应模型"""
|
||||
total: int
|
||||
main_careers: List[CareerResponse] = Field(default_factory=list, description="主职业列表")
|
||||
sub_careers: List[CareerResponse] = Field(default_factory=list, description="副职业列表")
|
||||
|
||||
|
||||
class CareerGenerateRequest(BaseModel):
|
||||
"""AI生成职业体系的请求模型"""
|
||||
project_id: str = Field(..., description="项目ID")
|
||||
main_career_count: int = Field(5, description="主职业数量", ge=1, le=20)
|
||||
sub_career_count: int = Field(8, description="副职业数量", ge=0, le=30)
|
||||
enable_mcp: bool = Field(False, description="是否启用MCP工具增强")
|
||||
|
||||
|
||||
# ===== 角色职业关联相关 =====
|
||||
|
||||
class CharacterCareerBase(BaseModel):
|
||||
"""角色职业关联基础模型"""
|
||||
career_id: str = Field(..., description="职业ID")
|
||||
career_type: str = Field(..., description="main(主职业)/sub(副职业)")
|
||||
current_stage: int = Field(1, description="当前阶段", ge=1)
|
||||
stage_progress: int = Field(0, description="阶段内进度(0-100)", ge=0, le=100)
|
||||
started_at: Optional[str] = Field(None, description="开始修炼时间")
|
||||
reached_current_stage_at: Optional[str] = Field(None, description="到达当前阶段时间")
|
||||
notes: Optional[str] = Field(None, description="备注")
|
||||
|
||||
|
||||
class CharacterCareerCreate(CharacterCareerBase):
|
||||
"""创建角色职业关联的请求模型"""
|
||||
character_id: str = Field(..., description="角色ID")
|
||||
|
||||
|
||||
class CharacterCareerUpdate(BaseModel):
|
||||
"""更新角色职业关联的请求模型"""
|
||||
current_stage: Optional[int] = Field(None, ge=1)
|
||||
stage_progress: Optional[int] = Field(None, ge=0, le=100)
|
||||
reached_current_stage_at: Optional[str] = None
|
||||
notes: Optional[str] = None
|
||||
|
||||
|
||||
class CharacterCareerDetail(BaseModel):
|
||||
"""角色职业详情模型(包含职业信息)"""
|
||||
id: str
|
||||
character_id: str
|
||||
career_id: str
|
||||
career_name: str = Field(..., description="职业名称")
|
||||
career_type: str
|
||||
current_stage: int
|
||||
stage_name: str = Field(..., description="当前阶段名称")
|
||||
stage_description: Optional[str] = Field(None, description="当前阶段描述")
|
||||
stage_progress: int
|
||||
max_stage: int = Field(..., description="该职业的最大阶段")
|
||||
started_at: Optional[str] = None
|
||||
reached_current_stage_at: Optional[str] = None
|
||||
notes: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class CharacterCareerResponse(BaseModel):
|
||||
"""角色职业响应模型"""
|
||||
main_career: Optional[CharacterCareerDetail] = Field(None, description="主职业")
|
||||
sub_careers: List[CharacterCareerDetail] = Field(default_factory=list, description="副职业列表")
|
||||
|
||||
|
||||
class SetMainCareerRequest(BaseModel):
|
||||
"""设置主职业请求模型"""
|
||||
career_id: str = Field(..., description="职业ID")
|
||||
current_stage: int = Field(1, description="当前阶段", ge=1)
|
||||
started_at: Optional[str] = Field(None, description="开始修炼时间")
|
||||
|
||||
|
||||
class AddSubCareerRequest(BaseModel):
|
||||
"""添加副职业请求模型"""
|
||||
career_id: str = Field(..., description="职业ID")
|
||||
current_stage: int = Field(1, description="当前阶段", ge=1)
|
||||
started_at: Optional[str] = Field(None, description="开始修炼时间")
|
||||
|
||||
|
||||
class UpdateCareerStageRequest(BaseModel):
|
||||
"""更新职业阶段请求模型"""
|
||||
current_stage: int = Field(..., description="新的阶段", ge=1)
|
||||
stage_progress: int = Field(0, description="阶段进度", ge=0, le=100)
|
||||
reached_current_stage_at: Optional[str] = Field(None, description="到达时间")
|
||||
notes: Optional[str] = Field(None, description="备注")
|
||||
@@ -1,6 +1,6 @@
|
||||
"""章节相关的Pydantic模型"""
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@@ -12,6 +12,9 @@ class ChapterBase(BaseModel):
|
||||
summary: Optional[str] = Field(None, description="章节摘要")
|
||||
word_count: Optional[int] = Field(0, description="字数")
|
||||
status: Optional[str] = Field("draft", description="章节状态")
|
||||
outline_id: Optional[str] = Field(None, description="关联的大纲ID")
|
||||
sub_index: Optional[int] = Field(1, description="大纲下的子章节序号")
|
||||
expansion_plan: Optional[str] = Field(None, description="展开规划详情(JSON)")
|
||||
|
||||
|
||||
class ChapterCreate(BaseModel):
|
||||
@@ -22,6 +25,9 @@ class ChapterCreate(BaseModel):
|
||||
content: Optional[str] = Field(None, description="章节内容")
|
||||
summary: Optional[str] = Field(None, description="章节摘要")
|
||||
status: Optional[str] = Field("draft", description="章节状态")
|
||||
outline_id: Optional[str] = Field(None, description="关联的大纲ID")
|
||||
sub_index: Optional[int] = Field(1, description="大纲下的子章节序号")
|
||||
expansion_plan: Optional[str] = Field(None, description="展开规划详情(JSON)")
|
||||
|
||||
|
||||
class ChapterUpdate(BaseModel):
|
||||
@@ -44,11 +50,15 @@ class ChapterResponse(BaseModel):
|
||||
summary: Optional[str] = None
|
||||
word_count: int = 0
|
||||
status: str
|
||||
outline_id: Optional[str] = None
|
||||
sub_index: Optional[int] = 1
|
||||
expansion_plan: Optional[str] = None
|
||||
outline_title: Optional[str] = None # 大纲标题(从Outline表联查)
|
||||
outline_order: Optional[int] = None # 大纲排序序号(从Outline表联查)
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ChapterListResponse(BaseModel):
|
||||
@@ -65,4 +75,93 @@ class ChapterGenerateRequest(BaseModel):
|
||||
description="目标字数,默认3000字",
|
||||
ge=500, # 最小500字
|
||||
le=10000 # 最大10000字
|
||||
)
|
||||
)
|
||||
enable_mcp: bool = Field(True, description="是否启用MCP工具增强(搜索参考资料)")
|
||||
model: Optional[str] = Field(None, description="指定使用的AI模型,不提供则使用用户默认模型")
|
||||
narrative_perspective: Optional[str] = Field(None, description="临时人称视角:first_person/third_person/omniscient,不提供则使用项目默认")
|
||||
|
||||
|
||||
class BatchGenerateRequest(BaseModel):
|
||||
"""批量生成章节的请求模型"""
|
||||
start_chapter_number: int = Field(..., description="起始章节序号")
|
||||
count: int = Field(..., description="生成章节数量", ge=1, le=20)
|
||||
style_id: Optional[int] = Field(None, description="写作风格ID")
|
||||
target_word_count: Optional[int] = Field(
|
||||
3000,
|
||||
description="目标字数,默认3000字",
|
||||
ge=500,
|
||||
le=10000
|
||||
)
|
||||
enable_analysis: bool = Field(False, description="是否启用同步分析")
|
||||
enable_mcp: bool = Field(True, description="是否启用MCP工具增强(搜索参考资料)")
|
||||
max_retries: int = Field(3, description="每个章节的最大重试次数", ge=0, le=5)
|
||||
model: Optional[str] = Field(None, description="指定使用的AI模型,不提供则使用用户默认模型")
|
||||
|
||||
|
||||
class BatchGenerateResponse(BaseModel):
|
||||
"""批量生成响应模型"""
|
||||
batch_id: str = Field(..., description="批次ID")
|
||||
message: str = Field(..., description="响应消息")
|
||||
chapters_to_generate: list[dict] = Field(..., description="待生成章节列表")
|
||||
estimated_time_minutes: int = Field(..., description="预估耗时(分钟)")
|
||||
|
||||
|
||||
class BatchGenerateStatusResponse(BaseModel):
|
||||
"""批量生成状态响应模型"""
|
||||
batch_id: str
|
||||
status: str
|
||||
total: int
|
||||
completed: int
|
||||
current_chapter_id: Optional[str] = None
|
||||
current_chapter_number: Optional[int] = None
|
||||
current_retry_count: Optional[int] = None
|
||||
max_retries: Optional[int] = None
|
||||
failed_chapters: list[dict] = []
|
||||
created_at: Optional[str] = None
|
||||
started_at: Optional[str] = None
|
||||
completed_at: Optional[str] = None
|
||||
error_message: Optional[str] = None
|
||||
|
||||
|
||||
class SceneData(BaseModel):
|
||||
"""场景数据模型"""
|
||||
location: str = Field(..., description="场景地点")
|
||||
characters: List[str] = Field(..., description="参与角色列表")
|
||||
purpose: str = Field(..., description="场景目的")
|
||||
|
||||
|
||||
class ExpansionPlanUpdate(BaseModel):
|
||||
"""章节规划更新模型"""
|
||||
summary: Optional[str] = Field(None, description="章节情节概要")
|
||||
key_events: Optional[List[str]] = Field(None, description="关键事件列表")
|
||||
character_focus: Optional[List[str]] = Field(None, description="涉及角色列表")
|
||||
emotional_tone: Optional[str] = Field(None, description="情感基调")
|
||||
narrative_goal: Optional[str] = Field(None, description="叙事目标")
|
||||
conflict_type: Optional[str] = Field(None, description="冲突类型")
|
||||
estimated_words: Optional[int] = Field(None, description="预估字数", ge=500, le=10000)
|
||||
scenes: Optional[List[SceneData]] = Field(None, description="场景列表")
|
||||
|
||||
model_config = ConfigDict(json_schema_extra={
|
||||
"example": {
|
||||
"key_events": ["主角遇到挑战", "关键决策时刻"],
|
||||
"character_focus": ["张三", "李四"],
|
||||
"emotional_tone": "紧张激烈",
|
||||
"narrative_goal": "推进主线剧情",
|
||||
"conflict_type": "内心冲突",
|
||||
"estimated_words": 3000,
|
||||
"scenes": [
|
||||
{
|
||||
"location": "城市广场",
|
||||
"characters": ["张三", "李四"],
|
||||
"purpose": "初次相遇"
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
class ExpansionPlanResponse(BaseModel):
|
||||
"""章节规划响应模型"""
|
||||
id: str = Field(..., description="章节ID")
|
||||
expansion_plan: Optional[Dict[str, Any]] = Field(None, description="规划数据")
|
||||
message: str = Field(..., description="响应消息")
|
||||
@@ -1,5 +1,5 @@
|
||||
"""角色相关的Pydantic模型"""
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
@@ -21,6 +21,36 @@ class CharacterBase(BaseModel):
|
||||
traits: Optional[str] = Field(None, description="特征标签(JSON)")
|
||||
|
||||
|
||||
class CharacterCreate(BaseModel):
|
||||
"""手动创建角色的请求模型"""
|
||||
project_id: str = Field(..., description="项目ID")
|
||||
name: str = Field(..., description="角色/组织姓名")
|
||||
age: Optional[str] = Field(None, description="年龄")
|
||||
gender: Optional[str] = Field(None, description="性别")
|
||||
is_organization: bool = Field(False, description="是否为组织")
|
||||
role_type: Optional[str] = Field("supporting", description="角色类型:protagonist/supporting/antagonist")
|
||||
personality: Optional[str] = Field(None, description="性格特点/组织特性")
|
||||
background: Optional[str] = Field(None, description="背景故事")
|
||||
appearance: Optional[str] = Field(None, description="外貌特征")
|
||||
relationships: Optional[str] = Field(None, description="人际关系(JSON)")
|
||||
organization_type: Optional[str] = Field(None, description="组织类型")
|
||||
organization_purpose: Optional[str] = Field(None, description="组织目的")
|
||||
organization_members: Optional[str] = Field(None, description="组织成员(JSON)")
|
||||
traits: Optional[str] = Field(None, description="特征标签(JSON)")
|
||||
avatar_url: Optional[str] = Field(None, description="头像URL")
|
||||
|
||||
# 组织额外字段
|
||||
power_level: Optional[int] = Field(None, description="组织势力等级(0-100)")
|
||||
location: Optional[str] = Field(None, description="组织所在地")
|
||||
motto: Optional[str] = Field(None, description="组织格言/口号")
|
||||
color: Optional[str] = Field(None, description="组织代表颜色")
|
||||
|
||||
# 职业字段
|
||||
main_career_id: Optional[str] = Field(None, description="主职业ID")
|
||||
main_career_stage: Optional[int] = Field(None, description="主职业阶段")
|
||||
sub_careers: Optional[str] = Field(None, description="副职业列表JSON字符串")
|
||||
|
||||
|
||||
class CharacterUpdate(BaseModel):
|
||||
"""更新角色的请求模型"""
|
||||
name: Optional[str] = None
|
||||
@@ -36,6 +66,17 @@ class CharacterUpdate(BaseModel):
|
||||
organization_purpose: Optional[str] = None
|
||||
organization_members: Optional[str] = None
|
||||
traits: Optional[str] = None
|
||||
|
||||
# 组织额外字段(会同步到Organization表)
|
||||
power_level: Optional[int] = Field(None, description="组织势力等级(0-100)")
|
||||
location: Optional[str] = Field(None, description="组织所在地")
|
||||
motto: Optional[str] = Field(None, description="组织格言/口号")
|
||||
color: Optional[str] = Field(None, description="组织代表颜色")
|
||||
|
||||
# 职业字段(会同步到CharacterCareer表)
|
||||
main_career_id: Optional[str] = Field(None, description="主职业ID")
|
||||
main_career_stage: Optional[int] = Field(None, description="主职业阶段")
|
||||
sub_careers: Optional[str] = Field(None, description="副职业列表JSON字符串")
|
||||
|
||||
|
||||
class CharacterResponse(CharacterBase):
|
||||
@@ -46,8 +87,18 @@ class CharacterResponse(CharacterBase):
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
# 组织额外字段(从Organization表关联)
|
||||
power_level: Optional[int] = Field(None, description="组织势力等级(0-100)")
|
||||
location: Optional[str] = Field(None, description="组织所在地")
|
||||
motto: Optional[str] = Field(None, description="组织格言/口号")
|
||||
color: Optional[str] = Field(None, description="组织代表颜色")
|
||||
|
||||
# 职业信息字段
|
||||
main_career_id: Optional[str] = Field(None, description="主职业ID")
|
||||
main_career_stage: Optional[int] = Field(None, description="主职业阶段")
|
||||
sub_careers: Optional[List[Dict[str, Any]]] = Field(None, description="副职业列表")
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class CharacterGenerateRequest(BaseModel):
|
||||
@@ -57,8 +108,7 @@ class CharacterGenerateRequest(BaseModel):
|
||||
role_type: Optional[str] = Field(None, description="角色类型")
|
||||
background: Optional[str] = Field(None, description="角色背景")
|
||||
requirements: Optional[str] = Field(None, description="特殊要求")
|
||||
provider: Optional[str] = Field(None, description="AI提供商")
|
||||
model: Optional[str] = Field(None, description="AI模型")
|
||||
enable_mcp: bool = Field(True, description="是否启用MCP工具增强(搜索人物原型参考)")
|
||||
|
||||
|
||||
class CharacterListResponse(BaseModel):
|
||||
|
||||
@@ -19,6 +19,11 @@ class ChapterExportData(BaseModel):
|
||||
word_count: int = 0
|
||||
status: str = "draft"
|
||||
created_at: Optional[str] = None
|
||||
|
||||
# 大纲细化功能新增字段
|
||||
outline_title: Optional[str] = None # 关联的大纲标题(用于导入时重建关联)
|
||||
sub_index: Optional[int] = None # 大纲下的子章节序号
|
||||
expansion_plan: Optional[Dict[str, Any]] = None # 展开规划详情(JSON对象)
|
||||
|
||||
|
||||
class CharacterExportData(BaseModel):
|
||||
@@ -31,9 +36,20 @@ class CharacterExportData(BaseModel):
|
||||
personality: Optional[str] = None
|
||||
background: Optional[str] = None
|
||||
appearance: Optional[str] = None
|
||||
relationships: Optional[str] = None
|
||||
traits: Optional[List[str]] = None
|
||||
organization_type: Optional[str] = None
|
||||
organization_purpose: Optional[str] = None
|
||||
organization_members: Optional[str] = None
|
||||
avatar_url: Optional[str] = None
|
||||
main_career_id: Optional[str] = None
|
||||
main_career_stage: Optional[int] = None
|
||||
sub_careers: Optional[str] = None
|
||||
# 组织专属字段
|
||||
power_level: Optional[int] = None
|
||||
location: Optional[str] = None
|
||||
motto: Optional[str] = None
|
||||
color: Optional[str] = None
|
||||
created_at: Optional[str] = None
|
||||
|
||||
|
||||
@@ -133,4 +149,28 @@ class ImportResult(BaseModel):
|
||||
project_id: Optional[str] = None
|
||||
message: str
|
||||
statistics: Dict[str, int] = {}
|
||||
details: Optional[Dict[str, List[str]]] = None
|
||||
warnings: List[str] = []
|
||||
|
||||
|
||||
class CharactersExportRequest(BaseModel):
|
||||
"""角色/组织批量导出请求"""
|
||||
character_ids: List[str] = Field(..., description="要导出的角色/组织ID列表")
|
||||
|
||||
|
||||
class CharactersExportData(BaseModel):
|
||||
"""角色/组织批量导出数据"""
|
||||
version: str = "1.0.0"
|
||||
export_time: str
|
||||
export_type: str = "characters"
|
||||
count: int
|
||||
data: List[CharacterExportData]
|
||||
|
||||
|
||||
class CharactersImportResult(BaseModel):
|
||||
"""角色/组织导入结果"""
|
||||
success: bool
|
||||
message: str
|
||||
statistics: Dict[str, int]
|
||||
details: Dict[str, List[str]]
|
||||
warnings: List[str] = []
|
||||
@@ -0,0 +1,104 @@
|
||||
"""MCP插件Pydantic模式"""
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class MCPToolSchema(BaseModel):
|
||||
"""MCP工具定义"""
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
inputSchema: Optional[Dict[str, Any]] = None
|
||||
category: Optional[str] = None
|
||||
|
||||
|
||||
class MCPPluginBase(BaseModel):
|
||||
"""插件基础模式"""
|
||||
plugin_name: str = Field(..., description="插件唯一标识")
|
||||
display_name: Optional[str] = Field(None, description="显示名称")
|
||||
description: Optional[str] = Field(None, description="插件描述")
|
||||
plugin_type: str = Field(default="http", description="插件类型:http/stdio")
|
||||
category: str = Field(default="general", description="分类")
|
||||
sort_order: int = Field(default=0, description="排序顺序")
|
||||
|
||||
|
||||
class MCPPluginCreate(MCPPluginBase):
|
||||
"""创建插件"""
|
||||
server_url: Optional[str] = Field(None, description="服务器URL(HTTP类型)")
|
||||
command: Optional[str] = Field(None, description="启动命令(stdio类型)")
|
||||
args: Optional[List[str]] = Field(None, description="命令参数")
|
||||
env: Optional[Dict[str, str]] = Field(None, description="环境变量")
|
||||
headers: Optional[Dict[str, str]] = Field(None, description="HTTP请求头")
|
||||
config: Optional[Dict[str, Any]] = Field(None, description="插件特定配置")
|
||||
enabled: bool = Field(default=True, description="是否启用")
|
||||
|
||||
|
||||
class MCPPluginSimpleCreate(BaseModel):
|
||||
"""简化的插件创建(通过标准MCP配置JSON)"""
|
||||
config_json: str = Field(..., description="标准MCP配置JSON字符串")
|
||||
enabled: bool = Field(default=True, description="是否启用")
|
||||
category: str = Field(default="general", description="插件分类")
|
||||
|
||||
|
||||
class MCPPluginUpdate(BaseModel):
|
||||
"""更新插件"""
|
||||
display_name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
server_url: Optional[str] = None
|
||||
command: Optional[str] = None
|
||||
args: Optional[List[str]] = None
|
||||
env: Optional[Dict[str, str]] = None
|
||||
headers: Optional[Dict[str, str]] = None
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
enabled: Optional[bool] = None
|
||||
category: Optional[str] = None
|
||||
sort_order: Optional[int] = None
|
||||
|
||||
|
||||
class MCPPluginResponse(BaseModel):
|
||||
"""插件响应 - 优化后只返回必要字段"""
|
||||
id: str
|
||||
plugin_name: str
|
||||
display_name: str
|
||||
description: Optional[str] = None
|
||||
plugin_type: str
|
||||
category: str
|
||||
|
||||
# HTTP类型字段
|
||||
server_url: Optional[str] = None
|
||||
headers: Optional[Dict[str, str]] = None
|
||||
|
||||
# Stdio类型字段
|
||||
command: Optional[str] = None
|
||||
args: Optional[List[str]] = None
|
||||
env: Optional[Dict[str, str]] = None
|
||||
|
||||
# 状态字段
|
||||
enabled: bool
|
||||
status: str
|
||||
last_error: Optional[str] = None
|
||||
last_test_at: Optional[datetime] = None
|
||||
|
||||
# 时间戳
|
||||
created_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class MCPToolCall(BaseModel):
|
||||
"""工具调用请求"""
|
||||
plugin_id: str = Field(..., description="插件ID")
|
||||
tool_name: str = Field(..., description="工具名称")
|
||||
arguments: Dict[str, Any] = Field(default_factory=dict, description="工具参数")
|
||||
|
||||
|
||||
class MCPTestResult(BaseModel):
|
||||
"""测试结果"""
|
||||
success: bool
|
||||
message: str
|
||||
response_time_ms: Optional[float] = None
|
||||
tools_count: Optional[int] = None
|
||||
tools: Optional[List[MCPToolSchema]] = None
|
||||
error: Optional[str] = None
|
||||
error_type: Optional[str] = None
|
||||
suggestions: Optional[List[str]] = None
|
||||
+144
-11
@@ -1,9 +1,74 @@
|
||||
"""大纲相关的Pydantic模型"""
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
# 角色预测相关Schema
|
||||
class CharacterPredictionRequest(BaseModel):
|
||||
"""角色预测请求"""
|
||||
project_id: str
|
||||
start_chapter: int
|
||||
chapter_count: int = 3
|
||||
plot_stage: str = "development"
|
||||
story_direction: Optional[str] = "自然延续"
|
||||
enable_mcp: bool = True
|
||||
|
||||
|
||||
class PredictedCharacter(BaseModel):
|
||||
"""预测的角色信息"""
|
||||
name: Optional[str] = None
|
||||
role_description: str
|
||||
suggested_role_type: str
|
||||
importance: str
|
||||
appearance_chapter: int
|
||||
key_abilities: List[str] = []
|
||||
plot_function: str
|
||||
relationship_suggestions: List[Dict[str, str]] = []
|
||||
|
||||
|
||||
class CharacterPredictionResponse(BaseModel):
|
||||
"""角色预测响应"""
|
||||
needs_new_characters: bool
|
||||
reason: str
|
||||
character_count: int
|
||||
predicted_characters: List[PredictedCharacter]
|
||||
|
||||
|
||||
# 组织预测相关Schema
|
||||
class OrganizationPredictionRequest(BaseModel):
|
||||
"""组织预测请求"""
|
||||
project_id: str
|
||||
start_chapter: int
|
||||
chapter_count: int = 3
|
||||
plot_stage: str = "development"
|
||||
story_direction: Optional[str] = "自然延续"
|
||||
enable_mcp: bool = True
|
||||
|
||||
|
||||
class PredictedOrganization(BaseModel):
|
||||
"""预测的组织信息"""
|
||||
name: Optional[str] = None
|
||||
organization_description: str
|
||||
organization_type: str
|
||||
importance: str
|
||||
appearance_chapter: int
|
||||
power_level: int = 50
|
||||
plot_function: str
|
||||
location: Optional[str] = None
|
||||
motto: Optional[str] = None
|
||||
initial_members: List[Dict[str, Any]] = []
|
||||
relationship_suggestions: List[Dict[str, str]] = []
|
||||
|
||||
|
||||
class OrganizationPredictionResponse(BaseModel):
|
||||
"""组织预测响应"""
|
||||
needs_new_organizations: bool
|
||||
reason: str
|
||||
organization_count: int
|
||||
predicted_organizations: List[PredictedOrganization]
|
||||
|
||||
|
||||
class OutlineBase(BaseModel):
|
||||
"""大纲基础模型"""
|
||||
title: str = Field(..., description="章节标题")
|
||||
@@ -38,8 +103,7 @@ class OutlineResponse(BaseModel):
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class OutlineGenerateRequest(BaseModel):
|
||||
@@ -61,6 +125,17 @@ class OutlineGenerateRequest(BaseModel):
|
||||
story_direction: Optional[str] = Field(None, description="故事发展方向提示(续写时使用)")
|
||||
plot_stage: str = Field("development", description="情节阶段: development(发展), climax(高潮), ending(结局)")
|
||||
keep_existing: bool = Field(False, description="是否保留现有大纲(续写时)")
|
||||
enable_mcp: bool = Field(True, description="是否启用MCP工具增强(搜索情节设计参考)")
|
||||
|
||||
# 自动角色引入相关参数
|
||||
enable_auto_characters: bool = Field(True, description="是否启用自动角色引入(根据剧情推进自动创建新角色)")
|
||||
require_character_confirmation: bool = Field(True, description="是否需要用户确认新角色(False则AI预测的角色直接创建)")
|
||||
confirmed_characters: Optional[List[Dict[str, Any]]] = Field(None, description="用户确认的角色列表(跳过预测直接创建)")
|
||||
|
||||
# 自动组织引入相关参数
|
||||
enable_auto_organizations: bool = Field(True, description="是否启用自动组织引入(根据剧情推进自动创建新组织)")
|
||||
require_organization_confirmation: bool = Field(True, description="是否需要用户确认新组织(False则AI预测的组织直接创建)")
|
||||
confirmed_organizations: Optional[List[Dict[str, Any]]] = Field(None, description="用户确认的组织列表(跳过预测直接创建)")
|
||||
|
||||
|
||||
class ChapterOutlineGenerateRequest(BaseModel):
|
||||
@@ -77,12 +152,70 @@ class OutlineListResponse(BaseModel):
|
||||
items: list[OutlineResponse]
|
||||
|
||||
|
||||
class OutlineReorderItem(BaseModel):
|
||||
"""单个大纲重排序项"""
|
||||
id: str = Field(..., description="大纲ID")
|
||||
order_index: int = Field(..., description="新的序号", ge=1)
|
||||
class ChapterPlanItem(BaseModel):
|
||||
"""单个章节规划项"""
|
||||
sub_index: int = Field(..., description="子章节序号", ge=1)
|
||||
title: str = Field(..., description="章节标题")
|
||||
plot_summary: str = Field(..., description="剧情摘要(200-300字)")
|
||||
key_events: list[str] = Field(..., description="关键事件列表")
|
||||
character_focus: list[str] = Field(..., description="主要涉及的角色")
|
||||
emotional_tone: str = Field(..., description="情感基调")
|
||||
narrative_goal: str = Field(..., description="叙事目标")
|
||||
conflict_type: str = Field(..., description="冲突类型")
|
||||
estimated_words: int = Field(3000, description="预计字数", ge=1000)
|
||||
scenes: Optional[list[str]] = Field(None, description="场景列表(可选)")
|
||||
|
||||
|
||||
class OutlineReorderRequest(BaseModel):
|
||||
"""大纲批量重排序请求"""
|
||||
orders: list[OutlineReorderItem] = Field(..., description="排序列表")
|
||||
class OutlineExpansionRequest(BaseModel):
|
||||
"""大纲展开为多章节的请求模型(outline_id从路径参数获取)"""
|
||||
target_chapter_count: int = Field(3, description="目标章节数", ge=1, le=10)
|
||||
expansion_strategy: str = Field("balanced", description="展开策略: balanced(均衡), climax(高潮重点), detail(细节丰富)")
|
||||
enable_scene_analysis: bool = Field(False, description="是否包含场景规划")
|
||||
auto_create_chapters: bool = Field(True, description="是否自动创建章节记录")
|
||||
provider: Optional[str] = Field(None, description="AI提供商")
|
||||
model: Optional[str] = Field(None, description="AI模型")
|
||||
|
||||
|
||||
class OutlineExpansionResponse(BaseModel):
|
||||
"""大纲展开响应模型"""
|
||||
outline_id: str = Field(..., description="大纲ID")
|
||||
outline_title: str = Field(..., description="大纲标题")
|
||||
target_chapter_count: int = Field(..., description="目标章节数")
|
||||
actual_chapter_count: int = Field(..., description="实际生成的章节数")
|
||||
expansion_strategy: str = Field(..., description="使用的展开策略")
|
||||
chapter_plans: list[ChapterPlanItem] = Field(..., description="章节规划列表")
|
||||
created_chapters: Optional[list] = Field(None, description="已创建的章节列表")
|
||||
|
||||
|
||||
class BatchOutlineExpansionRequest(BaseModel):
|
||||
"""批量大纲展开请求模型"""
|
||||
project_id: str = Field(..., description="项目ID")
|
||||
outline_ids: Optional[list[str]] = Field(None, description="要展开的大纲ID列表(为空则展开所有)")
|
||||
chapters_per_outline: int = Field(3, description="每个大纲的目标章节数", ge=1, le=10)
|
||||
expansion_strategy: str = Field("balanced", description="展开策略")
|
||||
enable_scene_analysis: bool = Field(False, description="是否包含场景规划")
|
||||
auto_create_chapters: bool = Field(True, description="是否自动创建章节记录")
|
||||
provider: Optional[str] = Field(None, description="AI提供商")
|
||||
model: Optional[str] = Field(None, description="AI模型")
|
||||
|
||||
|
||||
class BatchOutlineExpansionResponse(BaseModel):
|
||||
"""批量大纲展开响应模型"""
|
||||
project_id: str = Field(..., description="项目ID")
|
||||
total_outlines_expanded: int = Field(..., description="总共展开的大纲数")
|
||||
total_chapters_created: int = Field(..., description="总共创建的章节数")
|
||||
expansion_results: list[OutlineExpansionResponse] = Field(..., description="展开结果列表")
|
||||
skipped_outlines: Optional[list[dict]] = Field(None, description="跳过的大纲列表(已展开)")
|
||||
|
||||
|
||||
class CreateChaptersFromPlansRequest(BaseModel):
|
||||
"""根据已有规划创建章节的请求模型"""
|
||||
chapter_plans: list[ChapterPlanItem] = Field(..., description="章节规划列表(来自之前的AI生成结果)")
|
||||
|
||||
|
||||
class CreateChaptersFromPlansResponse(BaseModel):
|
||||
"""根据已有规划创建章节的响应模型"""
|
||||
outline_id: str = Field(..., description="大纲ID")
|
||||
outline_title: str = Field(..., description="大纲标题")
|
||||
chapters_created: int = Field(..., description="创建的章节数")
|
||||
created_chapters: list = Field(..., description="创建的章节列表")
|
||||
@@ -1,6 +1,6 @@
|
||||
"""项目相关的Pydantic模型"""
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Optional, Literal
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@@ -11,6 +11,10 @@ class ProjectBase(BaseModel):
|
||||
theme: Optional[str] = Field(None, description="主题")
|
||||
genre: Optional[str] = Field(None, description="小说类型")
|
||||
target_words: Optional[int] = Field(None, description="目标字数")
|
||||
outline_mode: Literal["one-to-one", "one-to-many"] = Field(
|
||||
default="one-to-many",
|
||||
description="大纲章节模式: one-to-one(传统模式,1大纲→1章节) 或 one-to-many(细化模式,1大纲→N章节)"
|
||||
)
|
||||
|
||||
|
||||
class ProjectCreate(ProjectBase):
|
||||
@@ -51,11 +55,11 @@ class ProjectResponse(ProjectBase):
|
||||
chapter_count: Optional[int] = None
|
||||
narrative_perspective: Optional[str] = None
|
||||
character_count: Optional[int] = None
|
||||
outline_mode: str # 显式声明以确保响应中包含
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ProjectListResponse(BaseModel):
|
||||
@@ -73,6 +77,10 @@ class ProjectWizardRequest(BaseModel):
|
||||
narrative_perspective: str = Field(..., description="叙事视角")
|
||||
character_count: int = Field(5, ge=5, description="角色数量(至少5个)")
|
||||
target_words: Optional[int] = Field(None, description="目标字数")
|
||||
outline_mode: Literal["one-to-one", "one-to-many"] = Field(
|
||||
default="one-to-many",
|
||||
description="大纲章节模式"
|
||||
)
|
||||
|
||||
|
||||
class WorldBuildingResponse(BaseModel):
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
"""提示词模板相关的Pydantic模型"""
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class PromptTemplateBase(BaseModel):
|
||||
"""提示词模板基础模型"""
|
||||
template_key: str = Field(..., description="模板键名")
|
||||
template_name: str = Field(..., description="模板显示名称")
|
||||
template_content: str = Field(..., description="模板内容")
|
||||
description: Optional[str] = Field(None, description="模板描述")
|
||||
category: Optional[str] = Field(None, description="模板分类")
|
||||
parameters: Optional[str] = Field(None, description="模板参数定义(JSON)")
|
||||
is_active: bool = Field(True, description="是否启用")
|
||||
|
||||
|
||||
class PromptTemplateCreate(PromptTemplateBase):
|
||||
"""创建提示词模板请求模型"""
|
||||
pass
|
||||
|
||||
|
||||
class PromptTemplateUpdate(BaseModel):
|
||||
"""更新提示词模板请求模型"""
|
||||
template_name: Optional[str] = Field(None, description="模板显示名称")
|
||||
template_content: Optional[str] = Field(None, description="模板内容")
|
||||
description: Optional[str] = Field(None, description="模板描述")
|
||||
category: Optional[str] = Field(None, description="模板分类")
|
||||
parameters: Optional[str] = Field(None, description="模板参数定义(JSON)")
|
||||
is_active: Optional[bool] = Field(None, description="是否启用")
|
||||
|
||||
|
||||
class PromptTemplateResponse(PromptTemplateBase):
|
||||
"""提示词模板响应模型"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
is_system_default: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class PromptTemplateListResponse(BaseModel):
|
||||
"""提示词模板列表响应"""
|
||||
templates: List[PromptTemplateResponse]
|
||||
total: int
|
||||
categories: List[str]
|
||||
|
||||
|
||||
class PromptTemplateCategoryResponse(BaseModel):
|
||||
"""提示词模板分类响应"""
|
||||
category: str
|
||||
count: int
|
||||
templates: List[PromptTemplateResponse]
|
||||
|
||||
|
||||
class PromptTemplateExportItem(BaseModel):
|
||||
"""提示词模板导出项模型"""
|
||||
template_key: str = Field(..., description="模板键名")
|
||||
template_name: str = Field(..., description="模板显示名称")
|
||||
template_content: str = Field(..., description="模板内容")
|
||||
description: Optional[str] = Field(None, description="模板描述")
|
||||
category: Optional[str] = Field(None, description="模板分类")
|
||||
parameters: Optional[str] = Field(None, description="模板参数定义(JSON)")
|
||||
is_active: bool = Field(True, description="是否启用")
|
||||
is_customized: bool = Field(..., description="是否为用户自定义(false=系统默认,true=用户自定义)")
|
||||
system_content_hash: Optional[str] = Field(None, description="系统默认内容的哈希值,用于比对")
|
||||
|
||||
|
||||
class PromptTemplateExport(BaseModel):
|
||||
"""提示词模板导出模型"""
|
||||
templates: List[PromptTemplateExportItem]
|
||||
export_time: datetime
|
||||
version: str = "2.0"
|
||||
statistics: Optional[dict] = Field(None, description="导出统计信息")
|
||||
|
||||
|
||||
class PromptTemplateImportResult(BaseModel):
|
||||
"""提示词模板导入结果"""
|
||||
message: str
|
||||
statistics: dict = Field(..., description="导入统计信息")
|
||||
converted_templates: List[dict] = Field(default_factory=list, description="被转换为自定义的模板列表")
|
||||
|
||||
|
||||
class PromptTemplatePreviewRequest(BaseModel):
|
||||
"""提示词模板预览请求"""
|
||||
template_content: str = Field(..., description="模板内容")
|
||||
parameters: dict = Field(..., description="参数字典")
|
||||
@@ -0,0 +1,65 @@
|
||||
"""章节重新生成相关的Schema定义"""
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class PreserveElementsConfig(BaseModel):
|
||||
"""保留元素配置"""
|
||||
preserve_structure: bool = Field(False, description="是否保留整体结构")
|
||||
preserve_dialogues: List[str] = Field(default_factory=list, description="需要保留的对话片段关键词")
|
||||
preserve_plot_points: List[str] = Field(default_factory=list, description="需要保留的情节点关键词")
|
||||
preserve_character_traits: bool = Field(True, description="保持角色性格一致")
|
||||
|
||||
|
||||
class ChapterRegenerateRequest(BaseModel):
|
||||
"""章节重新生成请求"""
|
||||
|
||||
# 修改来源
|
||||
modification_source: str = Field("custom", description="修改来源: custom/analysis_suggestions/mixed")
|
||||
|
||||
# 基于分析建议
|
||||
selected_suggestion_indices: Optional[List[int]] = Field(None, description="选中的建议索引列表")
|
||||
|
||||
# 自定义修改指令
|
||||
custom_instructions: Optional[str] = Field(None, description="用户自定义的修改要求")
|
||||
|
||||
# 保留配置
|
||||
preserve_elements: Optional[PreserveElementsConfig] = Field(None, description="保留元素配置")
|
||||
|
||||
# 生成参数
|
||||
style_id: Optional[int] = Field(None, description="写作风格ID")
|
||||
target_word_count: int = Field(3000, description="目标字数", ge=500, le=10000)
|
||||
focus_areas: List[str] = Field(default_factory=list, description="重点优化方向")
|
||||
|
||||
# 版本管理
|
||||
save_as_version: bool = Field(True, description="是否保存为新版本")
|
||||
version_note: Optional[str] = Field(None, description="版本说明", max_length=500)
|
||||
auto_apply: bool = Field(False, description="是否自动应用(替换当前内容)")
|
||||
|
||||
|
||||
class RegenerationTaskResponse(BaseModel):
|
||||
"""重新生成任务响应"""
|
||||
task_id: str
|
||||
chapter_id: str
|
||||
status: str
|
||||
message: str
|
||||
estimated_time_seconds: int = 120
|
||||
|
||||
|
||||
class RegenerationTaskStatus(BaseModel):
|
||||
"""重新生成任务状态"""
|
||||
task_id: str
|
||||
chapter_id: str
|
||||
status: str
|
||||
progress: int
|
||||
error_message: Optional[str] = None
|
||||
created_at: Optional[datetime] = None
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
|
||||
# 结果信息
|
||||
original_word_count: Optional[int] = None
|
||||
regenerated_word_count: Optional[int] = None
|
||||
version_number: Optional[int] = None
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""关系管理相关的Pydantic模型"""
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
|
||||
@@ -17,8 +17,7 @@ class RelationshipTypeResponse(BaseModel):
|
||||
description: Optional[str] = None
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# ============ 角色关系相关 ============
|
||||
@@ -27,7 +26,7 @@ class CharacterRelationshipBase(BaseModel):
|
||||
"""角色关系基础模型"""
|
||||
relationship_type_id: Optional[int] = Field(None, description="关系类型ID")
|
||||
relationship_name: Optional[str] = Field(None, description="自定义关系名称")
|
||||
intimacy_level: int = Field(50, ge=0, le=100, description="亲密度:0-100")
|
||||
intimacy_level: int = Field(50, ge=-100, le=100, description="亲密度:-100到100")
|
||||
status: str = Field("active", description="状态:active/broken/past/complicated")
|
||||
description: Optional[str] = Field(None, description="关系描述")
|
||||
started_at: Optional[str] = Field(None, description="关系开始时间(故事时间)")
|
||||
@@ -45,7 +44,7 @@ class CharacterRelationshipUpdate(BaseModel):
|
||||
"""更新角色关系的请求模型"""
|
||||
relationship_type_id: Optional[int] = None
|
||||
relationship_name: Optional[str] = None
|
||||
intimacy_level: Optional[int] = Field(None, ge=0, le=100)
|
||||
intimacy_level: Optional[int] = Field(None, ge=-100, le=100)
|
||||
status: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
started_at: Optional[str] = None
|
||||
@@ -62,8 +61,7 @@ class CharacterRelationshipResponse(CharacterRelationshipBase):
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class RelationshipGraphNode(BaseModel):
|
||||
@@ -127,8 +125,7 @@ class OrganizationResponse(OrganizationBase):
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class OrganizationDetailResponse(BaseModel):
|
||||
@@ -185,8 +182,7 @@ class OrganizationMemberResponse(OrganizationMemberBase):
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class OrganizationMemberDetailResponse(BaseModel):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""设置相关的Pydantic模型"""
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ class SettingsBase(BaseModel):
|
||||
llm_model: Optional[str] = Field(default="gpt-4", description="模型名称")
|
||||
temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0, description="温度参数")
|
||||
max_tokens: Optional[int] = Field(default=2000, ge=1, description="最大token数")
|
||||
system_prompt: Optional[str] = Field(default=None, description="系统级别提示词,每次AI调用都会使用")
|
||||
preferences: Optional[str] = Field(default=None, description="其他偏好设置(JSON)")
|
||||
|
||||
|
||||
@@ -34,4 +35,62 @@ class SettingsResponse(SettingsBase):
|
||||
id: str
|
||||
user_id: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
# ========== API配置预设相关模型 ==========
|
||||
|
||||
class APIKeyPresetConfig(BaseModel):
|
||||
"""预设配置内容"""
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
api_provider: str = Field(..., description="API提供商")
|
||||
api_key: str = Field(..., description="API密钥")
|
||||
api_base_url: Optional[str] = Field(None, description="自定义API地址")
|
||||
llm_model: str = Field(..., description="模型名称")
|
||||
temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="温度参数")
|
||||
max_tokens: int = Field(default=2000, ge=1, description="最大token数")
|
||||
|
||||
|
||||
class APIKeyPreset(BaseModel):
|
||||
"""API配置预设"""
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
id: str = Field(..., description="预设ID")
|
||||
name: str = Field(..., min_length=1, max_length=50, description="预设名称")
|
||||
description: Optional[str] = Field(None, max_length=200, description="预设描述")
|
||||
is_active: bool = Field(default=False, description="是否激活")
|
||||
created_at: datetime = Field(..., description="创建时间")
|
||||
config: APIKeyPresetConfig = Field(..., description="配置内容")
|
||||
|
||||
|
||||
class PresetCreateRequest(BaseModel):
|
||||
"""创建预设请求"""
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=50, description="预设名称")
|
||||
description: Optional[str] = Field(None, max_length=200, description="预设描述")
|
||||
config: APIKeyPresetConfig = Field(..., description="配置内容")
|
||||
|
||||
|
||||
class PresetUpdateRequest(BaseModel):
|
||||
"""更新预设请求"""
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=50, description="预设名称")
|
||||
description: Optional[str] = Field(None, max_length=200, description="预设描述")
|
||||
config: Optional[APIKeyPresetConfig] = Field(None, description="配置内容")
|
||||
|
||||
|
||||
class PresetResponse(APIKeyPreset):
|
||||
"""预设响应"""
|
||||
pass
|
||||
|
||||
|
||||
class PresetListResponse(BaseModel):
|
||||
"""预设列表响应"""
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
presets: List[PresetResponse] = Field(..., description="预设列表")
|
||||
total: int = Field(..., description="总数")
|
||||
active_preset_id: Optional[str] = Field(None, description="当前激活的预设ID")
|
||||
@@ -1,5 +1,5 @@
|
||||
"""写作风格 Schema"""
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
@@ -13,9 +13,13 @@ class WritingStyleBase(BaseModel):
|
||||
prompt_content: str = Field(..., description="风格提示词内容")
|
||||
|
||||
|
||||
class WritingStyleCreate(WritingStyleBase):
|
||||
"""创建写作风格(仅用于创建项目自定义风格)"""
|
||||
project_id: str = Field(..., description="所属项目ID")
|
||||
class WritingStyleCreate(BaseModel):
|
||||
"""创建写作风格(仅用于创建用户自定义风格)"""
|
||||
name: str = Field(..., description="风格名称")
|
||||
style_type: Optional[str] = Field(None, description="风格类型:preset/custom")
|
||||
preset_id: Optional[str] = Field(None, description="预设风格ID")
|
||||
description: Optional[str] = Field(None, description="风格描述")
|
||||
prompt_content: str = Field(..., description="风格提示词内容")
|
||||
|
||||
|
||||
class WritingStyleUpdate(BaseModel):
|
||||
@@ -33,7 +37,7 @@ class SetDefaultStyleRequest(BaseModel):
|
||||
class WritingStyleResponse(BaseModel):
|
||||
"""写作风格响应"""
|
||||
id: int
|
||||
project_id: Optional[str] = None # NULL 表示全局预设风格
|
||||
user_id: Optional[str] = None # NULL 表示全局预设风格
|
||||
name: str
|
||||
style_type: str
|
||||
preset_id: Optional[str] = None
|
||||
@@ -44,8 +48,7 @@ class WritingStyleResponse(BaseModel):
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class WritingStyleListResponse(BaseModel):
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
"""AI 客户端模块"""
|
||||
from .base_client import BaseAIClient
|
||||
from .openai_client import OpenAIClient
|
||||
from .anthropic_client import AnthropicClient
|
||||
|
||||
__all__ = ["BaseAIClient", "OpenAIClient", "AnthropicClient"]
|
||||
@@ -0,0 +1,142 @@
|
||||
"""Anthropic 客户端"""
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
|
||||
from anthropic import AsyncAnthropic
|
||||
|
||||
from app.logger import get_logger
|
||||
from app.services.ai_config import AIClientConfig, default_config
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AnthropicClient:
|
||||
"""Anthropic API 客户端"""
|
||||
|
||||
def __init__(self, api_key: str, base_url: Optional[str] = None, config: Optional[AIClientConfig] = None):
|
||||
self.config = config or default_config
|
||||
kwargs = {"api_key": api_key}
|
||||
if base_url:
|
||||
kwargs["base_url"] = base_url
|
||||
self.client = AsyncAnthropic(**kwargs)
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
kwargs = {
|
||||
"model": model,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"messages": messages,
|
||||
}
|
||||
if system_prompt:
|
||||
kwargs["system"] = system_prompt
|
||||
if tools:
|
||||
kwargs["tools"] = tools
|
||||
if tool_choice == "required":
|
||||
kwargs["tool_choice"] = {"type": "any"}
|
||||
elif tool_choice == "auto":
|
||||
kwargs["tool_choice"] = {"type": "auto"}
|
||||
|
||||
response = await self.client.messages.create(**kwargs)
|
||||
|
||||
tool_calls = []
|
||||
content = ""
|
||||
for block in response.content:
|
||||
if block.type == "tool_use":
|
||||
tool_calls.append({
|
||||
"id": block.id,
|
||||
"type": "function",
|
||||
"function": {"name": block.name, "arguments": block.input},
|
||||
})
|
||||
elif block.type == "text":
|
||||
content += block.text
|
||||
|
||||
return {
|
||||
"content": content,
|
||||
"tool_calls": tool_calls if tool_calls else None,
|
||||
"finish_reason": response.stop_reason,
|
||||
}
|
||||
|
||||
async def chat_completion_stream(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""
|
||||
流式生成,支持工具调用
|
||||
|
||||
Yields:
|
||||
Dict with keys:
|
||||
- content: str - 文本内容块
|
||||
- tool_calls: list - 工具调用列表(如果有)
|
||||
- done: bool - 是否结束
|
||||
"""
|
||||
kwargs = {
|
||||
"model": model,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"messages": messages,
|
||||
}
|
||||
if system_prompt:
|
||||
kwargs["system"] = system_prompt
|
||||
if tools:
|
||||
kwargs["tools"] = tools
|
||||
if tool_choice == "required":
|
||||
kwargs["tool_choice"] = {"type": "any"}
|
||||
elif tool_choice == "auto":
|
||||
kwargs["tool_choice"] = {"type": "auto"}
|
||||
|
||||
try:
|
||||
async with self.client.messages.stream(**kwargs) as stream:
|
||||
try:
|
||||
tool_calls = []
|
||||
async for chunk in stream:
|
||||
# 处理不同类型的块
|
||||
if chunk.type == "text_delta":
|
||||
yield {"content": chunk.text}
|
||||
elif chunk.type == "tool_use_delta":
|
||||
# 工具调用增量
|
||||
if not tool_calls or tool_calls[-1].get("id") != chunk.id:
|
||||
tool_calls.append({
|
||||
"id": chunk.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": chunk.name,
|
||||
"arguments": ""
|
||||
}
|
||||
})
|
||||
# 追加参数
|
||||
if tool_calls[-1]["function"]["arguments"] is None:
|
||||
tool_calls[-1]["function"]["arguments"] = ""
|
||||
tool_calls[-1]["function"]["arguments"] += chunk.input_gets_new_text or ""
|
||||
elif chunk.type == "message_delta":
|
||||
if chunk.stop_reason:
|
||||
# 流结束
|
||||
if tool_calls:
|
||||
yield {"tool_calls": tool_calls}
|
||||
yield {"done": True, "finish_reason": chunk.stop_reason}
|
||||
except GeneratorExit:
|
||||
# 生成器被关闭,这是正常的清理过程
|
||||
logger.debug("Anthropic 流式响应生成器被关闭(GeneratorExit)")
|
||||
raise
|
||||
except Exception as iter_error:
|
||||
logger.error(f"Anthropic 流式响应迭代出错: {str(iter_error)}")
|
||||
raise
|
||||
except GeneratorExit:
|
||||
# 重新抛出GeneratorExit,让调用方处理
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic 流式请求出错: {str(e)}")
|
||||
raise
|
||||
@@ -0,0 +1,154 @@
|
||||
"""AI 客户端基类"""
|
||||
import asyncio
|
||||
import hashlib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from app.logger import get_logger
|
||||
from app.services.ai_config import AIClientConfig, default_config
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 全局 HTTP 客户端池
|
||||
_http_client_pool: Dict[str, httpx.AsyncClient] = {}
|
||||
_global_semaphore: Optional[asyncio.Semaphore] = None
|
||||
|
||||
|
||||
def _get_semaphore(max_concurrent: int) -> asyncio.Semaphore:
|
||||
"""获取全局信号量"""
|
||||
global _global_semaphore
|
||||
if _global_semaphore is None:
|
||||
_global_semaphore = asyncio.Semaphore(max_concurrent)
|
||||
return _global_semaphore
|
||||
|
||||
|
||||
class BaseAIClient(ABC):
|
||||
"""AI HTTP 客户端基类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str,
|
||||
config: Optional[AIClientConfig] = None,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.config = config or default_config
|
||||
self.http_client = self._get_or_create_client()
|
||||
|
||||
def _get_client_key(self) -> str:
|
||||
"""生成客户端唯一键"""
|
||||
key_hash = hashlib.md5(self.api_key.encode()).hexdigest()[:8]
|
||||
return f"{self.__class__.__name__}_{self.base_url}_{key_hash}"
|
||||
|
||||
def _get_or_create_client(self) -> httpx.AsyncClient:
|
||||
"""获取或创建 HTTP 客户端"""
|
||||
client_key = self._get_client_key()
|
||||
|
||||
if client_key in _http_client_pool:
|
||||
client = _http_client_pool[client_key]
|
||||
if not client.is_closed:
|
||||
return client
|
||||
del _http_client_pool[client_key]
|
||||
|
||||
http_cfg = self.config.http
|
||||
client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(
|
||||
connect=http_cfg.connect_timeout,
|
||||
read=http_cfg.read_timeout,
|
||||
write=http_cfg.write_timeout,
|
||||
pool=http_cfg.pool_timeout,
|
||||
),
|
||||
limits=httpx.Limits(
|
||||
max_keepalive_connections=http_cfg.max_keepalive_connections,
|
||||
max_connections=http_cfg.max_connections,
|
||||
keepalive_expiry=http_cfg.keepalive_expiry,
|
||||
),
|
||||
)
|
||||
_http_client_pool[client_key] = client
|
||||
logger.info(f"✅ 创建 HTTP 客户端: {client_key}")
|
||||
return client
|
||||
|
||||
@abstractmethod
|
||||
def _build_headers(self) -> Dict[str, str]:
|
||||
"""构建请求头"""
|
||||
pass
|
||||
|
||||
async def _request_with_retry(
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
payload: Dict[str, Any],
|
||||
stream: bool = False,
|
||||
) -> Any:
|
||||
"""带重试的 HTTP 请求"""
|
||||
url = f"{self.base_url}{endpoint}"
|
||||
headers = self._build_headers()
|
||||
retry_cfg = self.config.retry
|
||||
rate_cfg = self.config.rate_limit
|
||||
|
||||
semaphore = _get_semaphore(rate_cfg.max_concurrent_requests)
|
||||
|
||||
async with semaphore:
|
||||
await asyncio.sleep(rate_cfg.request_delay)
|
||||
|
||||
for attempt in range(retry_cfg.max_retries):
|
||||
try:
|
||||
if attempt > 0:
|
||||
delay = min(
|
||||
retry_cfg.base_delay * (retry_cfg.exponential_base ** attempt),
|
||||
retry_cfg.max_delay,
|
||||
)
|
||||
logger.warning(f"⚠️ 重试 {attempt + 1}/{retry_cfg.max_retries},等待 {delay}s")
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
if stream:
|
||||
return self.http_client.stream(method, url, headers=headers, json=payload)
|
||||
|
||||
response = await self.http_client.request(method, url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code in retry_cfg.non_retryable_status_codes:
|
||||
raise
|
||||
if attempt == retry_cfg.max_retries - 1:
|
||||
raise
|
||||
except (httpx.ConnectError, httpx.TimeoutException):
|
||||
if attempt == retry_cfg.max_retries - 1:
|
||||
raise
|
||||
|
||||
@abstractmethod
|
||||
async def chat_completion(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""聊天补全"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def chat_completion_stream(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""流式聊天补全"""
|
||||
pass
|
||||
|
||||
|
||||
async def cleanup_all_clients():
|
||||
"""清理所有 HTTP 客户端"""
|
||||
for key, client in list(_http_client_pool.items()):
|
||||
if not client.is_closed:
|
||||
await client.aclose()
|
||||
_http_client_pool.clear()
|
||||
logger.info("✅ HTTP 客户端池已清理")
|
||||
@@ -0,0 +1,189 @@
|
||||
"""Gemini 客户端"""
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
import httpx
|
||||
from app.services.ai_config import AIClientConfig, default_config
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class GeminiClient:
|
||||
"""Google Gemini API 客户端"""
|
||||
|
||||
def __init__(self, api_key: str, base_url: Optional[str] = None, config: Optional[AIClientConfig] = None):
|
||||
self.api_key = api_key
|
||||
self.base_url = (base_url or "https://generativelanguage.googleapis.com/v1beta").rstrip("/")
|
||||
self.config = config or default_config
|
||||
http_cfg = self.config.http
|
||||
self.client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(
|
||||
connect=http_cfg.connect_timeout,
|
||||
read=http_cfg.read_timeout,
|
||||
write=http_cfg.write_timeout,
|
||||
pool=http_cfg.pool_timeout
|
||||
)
|
||||
)
|
||||
|
||||
def _convert_tools_to_gemini(self, tools: list) -> list:
|
||||
"""将 OpenAI 格式工具转换为 Gemini 格式"""
|
||||
gemini_tools = []
|
||||
for tool in tools:
|
||||
if tool.get("type") == "function":
|
||||
func = tool["function"]
|
||||
params = func.get("parameters", {}).copy() if func.get("parameters") else {}
|
||||
params.pop("$schema", None)
|
||||
params.pop("additionalProperties", None)
|
||||
if params and "type" not in params:
|
||||
params["type"] = "object"
|
||||
decl = {
|
||||
"name": func["name"],
|
||||
"description": func.get("description") or func["name"],
|
||||
}
|
||||
if params:
|
||||
decl["parameters"] = params
|
||||
gemini_tools.append(decl)
|
||||
return [{"functionDeclarations": gemini_tools}] if gemini_tools else []
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
url = f"{self.base_url}/models/{model}:generateContent?key={self.api_key}"
|
||||
|
||||
contents = []
|
||||
for msg in messages:
|
||||
role = "user" if msg["role"] == "user" else "model"
|
||||
contents.append({"role": role, "parts": [{"text": msg["content"]}]})
|
||||
|
||||
payload = {
|
||||
"contents": contents,
|
||||
"generationConfig": {"temperature": temperature, "maxOutputTokens": max_tokens}
|
||||
}
|
||||
if system_prompt:
|
||||
payload["systemInstruction"] = {"parts": [{"text": system_prompt}]}
|
||||
if tools:
|
||||
payload["tools"] = self._convert_tools_to_gemini(tools)
|
||||
|
||||
response = await self.client.post(url, json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
candidates = data.get("candidates", [])
|
||||
if not candidates or len(candidates) == 0:
|
||||
# 返回空内容而不是报错,保持流程继续
|
||||
return {
|
||||
"content": "",
|
||||
"tool_calls": None,
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
|
||||
parts = candidates[0].get("content", {}).get("parts", [])
|
||||
text = ""
|
||||
tool_calls = []
|
||||
|
||||
for part in parts:
|
||||
if "text" in part:
|
||||
text += part["text"]
|
||||
elif "functionCall" in part:
|
||||
fc = part["functionCall"]
|
||||
tool_calls.append({
|
||||
"id": f"call_{fc['name']}",
|
||||
"type": "function",
|
||||
"function": {"name": fc["name"], "arguments": fc.get("args", {})}
|
||||
})
|
||||
|
||||
return {
|
||||
"content": text,
|
||||
"tool_calls": tool_calls if tool_calls else None,
|
||||
"finish_reason": "tool_calls" if tool_calls else "stop"
|
||||
}
|
||||
|
||||
async def chat_completion_stream(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""
|
||||
流式生成,支持工具调用
|
||||
|
||||
Yields:
|
||||
Dict with keys:
|
||||
- content: str - 文本内容块
|
||||
- tool_calls: list - 工具调用列表(如果有)
|
||||
- done: bool - 是否结束
|
||||
"""
|
||||
url = f"{self.base_url}/models/{model}:streamGenerateContent?key={self.api_key}&alt=sse"
|
||||
|
||||
contents = []
|
||||
for msg in messages:
|
||||
role = "user" if msg["role"] == "user" else "model"
|
||||
contents.append({"role": role, "parts": [{"text": msg["content"]}]})
|
||||
|
||||
payload = {
|
||||
"contents": contents,
|
||||
"generationConfig": {"temperature": temperature, "maxOutputTokens": max_tokens}
|
||||
}
|
||||
if system_prompt:
|
||||
payload["systemInstruction"] = {"parts": [{"text": system_prompt}]}
|
||||
if tools:
|
||||
payload["tools"] = self._convert_tools_to_gemini(tools)
|
||||
|
||||
try:
|
||||
async with self.client.stream("POST", url, json=payload) as response:
|
||||
response.raise_for_status()
|
||||
try:
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
import json
|
||||
try:
|
||||
data = json.loads(line[6:])
|
||||
candidates = data.get("candidates", [])
|
||||
if candidates and len(candidates) > 0:
|
||||
parts = candidates[0].get("content", {}).get("parts", [])
|
||||
if parts and len(parts) > 0:
|
||||
text = ""
|
||||
function_calls = []
|
||||
for part in parts:
|
||||
if "text" in part:
|
||||
text += part["text"]
|
||||
elif "functionCall" in part:
|
||||
fc = part["functionCall"]
|
||||
function_calls.append({
|
||||
"id": f"call_{fc['name']}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": fc["name"],
|
||||
"arguments": fc.get("args", {})
|
||||
}
|
||||
})
|
||||
|
||||
if text:
|
||||
yield {"content": text}
|
||||
if function_calls:
|
||||
yield {"tool_calls": function_calls}
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
except GeneratorExit:
|
||||
# 生成器被关闭,这是正常的清理过程
|
||||
logger.debug("Gemini 流式响应生成器被关闭(GeneratorExit)")
|
||||
raise
|
||||
except Exception as iter_error:
|
||||
logger.error(f"Gemini 流式响应迭代出错: {str(iter_error)}")
|
||||
raise
|
||||
except GeneratorExit:
|
||||
# 重新抛出GeneratorExit,让调用方处理
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Gemini 流式请求出错: {str(e)}")
|
||||
raise
|
||||
@@ -0,0 +1,159 @@
|
||||
"""OpenAI 客户端"""
|
||||
import json
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
|
||||
from app.logger import get_logger
|
||||
from .base_client import BaseAIClient
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIClient(BaseAIClient):
|
||||
"""OpenAI API 客户端"""
|
||||
|
||||
def _build_headers(self) -> Dict[str, str]:
|
||||
return {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
def _build_payload(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
stream: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
}
|
||||
if stream:
|
||||
payload["stream"] = True
|
||||
if tools:
|
||||
# 清理 $schema 字段
|
||||
cleaned = []
|
||||
for t in tools:
|
||||
tc = t.copy()
|
||||
if "function" in tc and "parameters" in tc["function"]:
|
||||
tc["function"]["parameters"] = {
|
||||
k: v for k, v in tc["function"]["parameters"].items() if k != "$schema"
|
||||
}
|
||||
cleaned.append(tc)
|
||||
payload["tools"] = cleaned
|
||||
if tool_choice:
|
||||
payload["tool_choice"] = tool_choice
|
||||
return payload
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
payload = self._build_payload(messages, model, temperature, max_tokens, tools, tool_choice)
|
||||
|
||||
logger.debug(f"📤 OpenAI 请求 payload: {json.dumps(payload, ensure_ascii=False, indent=2)}")
|
||||
|
||||
data = await self._request_with_retry("POST", "/chat/completions", payload)
|
||||
|
||||
# 调试日志:输出原始响应
|
||||
logger.debug(f"📥 OpenAI 原始响应: {json.dumps(data, ensure_ascii=False, indent=2)}")
|
||||
|
||||
choices = data.get("choices", [])
|
||||
if not choices or len(choices) == 0:
|
||||
raise ValueError("API 返回空 choices 或 choices 为空列表")
|
||||
|
||||
choice = choices[0]
|
||||
message = choice.get("message", {})
|
||||
return {
|
||||
"content": message.get("content", ""),
|
||||
"tool_calls": message.get("tool_calls"),
|
||||
"finish_reason": choice.get("finish_reason"),
|
||||
}
|
||||
|
||||
async def chat_completion_stream(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""
|
||||
流式生成,支持工具调用
|
||||
|
||||
Yields:
|
||||
Dict with keys:
|
||||
- content: str - 文本内容块
|
||||
- tool_calls: list - 工具调用列表(如果有)
|
||||
- done: bool - 是否结束
|
||||
"""
|
||||
payload = self._build_payload(messages, model, temperature, max_tokens, tools, tool_choice, stream=True)
|
||||
|
||||
tool_calls_buffer = {} # 收集工具调用块
|
||||
|
||||
try:
|
||||
async with await self._request_with_retry("POST", "/chat/completions", payload, stream=True) as response:
|
||||
response.raise_for_status()
|
||||
try:
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
if data_str.strip() == "[DONE]":
|
||||
# 流结束,检查是否有工具调用需要处理
|
||||
if tool_calls_buffer:
|
||||
yield {"tool_calls": list(tool_calls_buffer.values()), "done": True}
|
||||
yield {"done": True}
|
||||
break
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
choices = data.get("choices", [])
|
||||
if choices and len(choices) > 0:
|
||||
delta = choices[0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
|
||||
# 检查工具调用
|
||||
tc_list = delta.get("tool_calls")
|
||||
if tc_list:
|
||||
for tc in tc_list:
|
||||
index = tc.get("index", 0)
|
||||
if index not in tool_calls_buffer:
|
||||
tool_calls_buffer[index] = tc
|
||||
else:
|
||||
existing = tool_calls_buffer[index]
|
||||
# 合并 function.arguments
|
||||
if "function" in tc and "function" in existing:
|
||||
if tc["function"].get("arguments"):
|
||||
existing["function"]["arguments"] = (
|
||||
existing["function"].get("arguments", "") +
|
||||
tc["function"]["arguments"]
|
||||
)
|
||||
|
||||
if content:
|
||||
yield {"content": content}
|
||||
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
except GeneratorExit:
|
||||
# 生成器被关闭,这是正常的清理过程
|
||||
logger.debug("流式响应生成器被关闭(GeneratorExit)")
|
||||
raise
|
||||
except Exception as iter_error:
|
||||
logger.error(f"流式响应迭代出错: {str(iter_error)}")
|
||||
raise
|
||||
except GeneratorExit:
|
||||
# 重新抛出GeneratorExit,让调用方处理
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"流式请求出错: {str(e)}")
|
||||
raise
|
||||
@@ -0,0 +1,44 @@
|
||||
"""AI 服务配置管理"""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class HTTPClientConfig:
|
||||
"""HTTP 客户端配置"""
|
||||
connect_timeout: float = 90.0
|
||||
read_timeout: float = 300.0
|
||||
write_timeout: float = 90.0
|
||||
pool_timeout: float = 90.0
|
||||
max_keepalive_connections: int = 50
|
||||
max_connections: int = 100
|
||||
keepalive_expiry: float = 60.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetryConfig:
|
||||
"""重试配置"""
|
||||
max_retries: int = 3
|
||||
base_delay: float = 0.2
|
||||
max_delay: float = 10.0
|
||||
exponential_base: int = 2
|
||||
non_retryable_status_codes: tuple = field(default_factory=lambda: (401, 403, 404))
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitConfig:
|
||||
"""限流配置"""
|
||||
max_concurrent_requests: int = 5
|
||||
request_delay: float = 0.2
|
||||
|
||||
|
||||
@dataclass
|
||||
class AIClientConfig:
|
||||
"""AI 客户端完整配置"""
|
||||
http: HTTPClientConfig = field(default_factory=HTTPClientConfig)
|
||||
retry: RetryConfig = field(default_factory=RetryConfig)
|
||||
rate_limit: RateLimitConfig = field(default_factory=RateLimitConfig)
|
||||
|
||||
|
||||
# 全局默认配置
|
||||
default_config = AIClientConfig()
|
||||
@@ -0,0 +1,6 @@
|
||||
"""AI Provider 模块"""
|
||||
from .base_provider import BaseAIProvider
|
||||
from .openai_provider import OpenAIProvider
|
||||
from .anthropic_provider import AnthropicProvider
|
||||
|
||||
__all__ = ["BaseAIProvider", "OpenAIProvider", "AnthropicProvider"]
|
||||
@@ -0,0 +1,161 @@
|
||||
"""Anthropic Provider"""
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from app.logger import get_logger
|
||||
from app.services.ai_clients.anthropic_client import AnthropicClient
|
||||
from .base_provider import BaseAIProvider
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AnthropicProvider(BaseAIProvider):
|
||||
"""Anthropic 提供商"""
|
||||
|
||||
def __init__(self, client: AnthropicClient):
|
||||
self.client = client
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
return await self.client.chat_completion(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
|
||||
async def generate_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
# 如果有工具,使用真正的流式工具调用
|
||||
if tools:
|
||||
logger.debug(f"🔧 AnthropicProvider: 有 {len(tools)} 个工具,使用流式处理")
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
actual_tool_choice = tool_choice if tool_choice else "auto"
|
||||
|
||||
tool_calls_buffer = []
|
||||
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
tool_choice=actual_tool_choice,
|
||||
):
|
||||
# 检查是否有工具调用
|
||||
if chunk.get("tool_calls"):
|
||||
tool_calls_buffer.extend(chunk["tool_calls"])
|
||||
logger.debug(f"🔧 收到工具调用: {len(chunk['tool_calls'])} 个")
|
||||
|
||||
# 检查是否结束
|
||||
if chunk.get("done"):
|
||||
if tool_calls_buffer:
|
||||
logger.info(f"🔧 流式结束,处理 {len(tool_calls_buffer)} 个工具调用")
|
||||
from app.mcp import mcp_client
|
||||
actual_user_id = user_id or ""
|
||||
tool_results = await mcp_client.batch_call_tools(
|
||||
user_id=actual_user_id,
|
||||
tool_calls=tool_calls_buffer
|
||||
)
|
||||
# 将工具结果注入到上下文中
|
||||
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
|
||||
|
||||
# 构建最终提示词,要求AI基于工具结果回答
|
||||
final_prompt = f"{prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"
|
||||
final_messages = [{"role": "user", "content": final_prompt}]
|
||||
|
||||
# 递归调用生成最终结果
|
||||
async for final_chunk in self._generate_with_tools(
|
||||
final_messages, model, temperature, max_tokens, system_prompt, tools, user_id
|
||||
):
|
||||
yield final_chunk
|
||||
break
|
||||
|
||||
# 输出文本内容
|
||||
if chunk.get("content"):
|
||||
yield chunk["content"]
|
||||
return
|
||||
|
||||
# 无工具时普通流式生成
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
):
|
||||
# 确保只 yield 字符串内容,避免 yield 字典导致类型错误
|
||||
if isinstance(chunk, dict):
|
||||
if chunk.get("content"):
|
||||
yield chunk["content"]
|
||||
else:
|
||||
yield chunk
|
||||
|
||||
async def _generate_with_tools(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: list = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""辅助方法:带工具的流式生成"""
|
||||
tool_calls_buffer = []
|
||||
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
):
|
||||
if chunk.get("tool_calls"):
|
||||
tool_calls_buffer.extend(chunk["tool_calls"])
|
||||
logger.debug(f"🔧 _generate_with_tools 收到工具调用: {len(chunk['tool_calls'])} 个")
|
||||
|
||||
if chunk.get("done"):
|
||||
if tool_calls_buffer:
|
||||
from app.mcp import mcp_client
|
||||
actual_user_id = user_id or ""
|
||||
tool_results = await mcp_client.batch_call_tools(
|
||||
user_id=actual_user_id,
|
||||
tool_calls=tool_calls_buffer
|
||||
)
|
||||
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
|
||||
|
||||
messages.append({"role": "user", "content": f"{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"})
|
||||
|
||||
async for final_chunk in self._generate_with_tools(
|
||||
messages, model, temperature, max_tokens, system_prompt, tools, user_id
|
||||
):
|
||||
yield final_chunk
|
||||
break
|
||||
|
||||
if chunk.get("content"):
|
||||
yield chunk["content"]
|
||||
@@ -0,0 +1,36 @@
|
||||
"""AI Provider 基类"""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
|
||||
class BaseAIProvider(ABC):
|
||||
"""AI 提供商抽象基类"""
|
||||
|
||||
@abstractmethod
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""生成文本"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def generate_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""流式生成"""
|
||||
pass
|
||||
@@ -0,0 +1,159 @@
|
||||
"""Gemini Provider"""
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from app.logger import get_logger
|
||||
from app.services.ai_clients.gemini_client import GeminiClient
|
||||
from .base_provider import BaseAIProvider
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class GeminiProvider(BaseAIProvider):
|
||||
def __init__(self, client: GeminiClient):
|
||||
self.client = client
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
return await self.client.chat_completion(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
|
||||
async def generate_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
# 如果有工具,使用真正的流式工具调用
|
||||
if tools:
|
||||
logger.debug(f"🔧 GeminiProvider: 有 {len(tools)} 个工具,使用流式处理")
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
actual_tool_choice = tool_choice if tool_choice else "auto"
|
||||
|
||||
tool_calls_buffer = []
|
||||
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
tool_choice=actual_tool_choice,
|
||||
):
|
||||
# 检查是否有工具调用
|
||||
if chunk.get("tool_calls"):
|
||||
tool_calls_buffer.extend(chunk["tool_calls"])
|
||||
logger.debug(f"🔧 收到工具调用: {len(chunk['tool_calls'])} 个")
|
||||
|
||||
# 检查是否结束
|
||||
if chunk.get("done"):
|
||||
if tool_calls_buffer:
|
||||
logger.info(f"🔧 流式结束,处理 {len(tool_calls_buffer)} 个工具调用")
|
||||
from app.mcp import mcp_client
|
||||
actual_user_id = user_id or ""
|
||||
tool_results = await mcp_client.batch_call_tools(
|
||||
user_id=actual_user_id,
|
||||
tool_calls=tool_calls_buffer
|
||||
)
|
||||
# 将工具结果注入到上下文中
|
||||
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
|
||||
|
||||
# 构建最终提示词,要求AI基于工具结果回答
|
||||
final_prompt = f"{prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"
|
||||
final_messages = [{"role": "user", "content": final_prompt}]
|
||||
|
||||
# 递归调用生成最终结果
|
||||
async for final_chunk in self._generate_with_tools(
|
||||
final_messages, model, temperature, max_tokens, system_prompt, tools, user_id
|
||||
):
|
||||
yield final_chunk
|
||||
break
|
||||
|
||||
# 输出文本内容
|
||||
if chunk.get("content"):
|
||||
yield chunk["content"]
|
||||
return
|
||||
|
||||
# 无工具时普通流式生成
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
):
|
||||
# 确保只 yield 字符串内容,避免 yield 字典导致类型错误
|
||||
if isinstance(chunk, dict):
|
||||
if chunk.get("content"):
|
||||
yield chunk["content"]
|
||||
else:
|
||||
yield chunk
|
||||
|
||||
async def _generate_with_tools(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: list = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""辅助方法:带工具的流式生成"""
|
||||
tool_calls_buffer = []
|
||||
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
):
|
||||
if chunk.get("tool_calls"):
|
||||
tool_calls_buffer.extend(chunk["tool_calls"])
|
||||
logger.debug(f"🔧 _generate_with_tools 收到工具调用: {len(chunk['tool_calls'])} 个")
|
||||
|
||||
if chunk.get("done"):
|
||||
if tool_calls_buffer:
|
||||
from app.mcp import mcp_client
|
||||
actual_user_id = user_id or ""
|
||||
tool_results = await mcp_client.batch_call_tools(
|
||||
user_id=actual_user_id,
|
||||
tool_calls=tool_calls_buffer
|
||||
)
|
||||
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
|
||||
|
||||
messages.append({"role": "user", "content": f"{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"})
|
||||
|
||||
async for final_chunk in self._generate_with_tools(
|
||||
messages, model, temperature, max_tokens, system_prompt, tools, user_id
|
||||
):
|
||||
yield final_chunk
|
||||
break
|
||||
|
||||
if chunk.get("content"):
|
||||
yield chunk["content"]
|
||||
@@ -0,0 +1,161 @@
|
||||
"""OpenAI Provider"""
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from app.logger import get_logger
|
||||
from app.services.ai_clients.openai_client import OpenAIClient
|
||||
from .base_provider import BaseAIProvider
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIProvider(BaseAIProvider):
|
||||
"""OpenAI 提供商"""
|
||||
|
||||
def __init__(self, client: OpenAIClient):
|
||||
self.client = client
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
return await self.client.chat_completion(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
|
||||
async def generate_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
# 如果有工具,使用真正的流式工具调用
|
||||
if tools:
|
||||
logger.debug(f"🔧 OpenAIProvider: 有 {len(tools)} 个工具,使用流式处理")
|
||||
actual_tool_choice = tool_choice if tool_choice else "auto"
|
||||
|
||||
tool_calls_buffer = []
|
||||
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
tools=tools,
|
||||
tool_choice=actual_tool_choice,
|
||||
):
|
||||
# 检查是否有工具调用
|
||||
if chunk.get("tool_calls"):
|
||||
tool_calls_buffer.extend(chunk["tool_calls"])
|
||||
logger.debug(f"🔧 收到工具调用: {len(chunk['tool_calls'])} 个")
|
||||
|
||||
# 检查是否结束
|
||||
if chunk.get("done"):
|
||||
if tool_calls_buffer:
|
||||
logger.info(f"🔧 流式结束,处理 {len(tool_calls_buffer)} 个工具调用")
|
||||
from app.mcp import mcp_client
|
||||
actual_user_id = user_id or ""
|
||||
tool_results = await mcp_client.batch_call_tools(
|
||||
user_id=actual_user_id,
|
||||
tool_calls=tool_calls_buffer
|
||||
)
|
||||
# 将工具结果注入到上下文中
|
||||
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
|
||||
|
||||
# 构建最终提示词,要求AI基于工具结果回答
|
||||
final_prompt = f"{prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"
|
||||
final_messages = messages.copy()
|
||||
final_messages.append({"role": "user", "content": final_prompt})
|
||||
|
||||
# 递归调用生成最终结果
|
||||
async for final_chunk in self._generate_with_tools(
|
||||
final_messages, model, temperature, max_tokens, tools, user_id
|
||||
):
|
||||
yield final_chunk
|
||||
break
|
||||
|
||||
# 输出文本内容
|
||||
if chunk.get("content"):
|
||||
yield chunk["content"]
|
||||
return
|
||||
|
||||
# 无工具时普通流式生成
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
):
|
||||
# 确保只 yield 字符串内容,避免 yield 字典导致类型错误
|
||||
if isinstance(chunk, dict):
|
||||
if chunk.get("content"):
|
||||
yield chunk["content"]
|
||||
else:
|
||||
yield chunk
|
||||
|
||||
async def _generate_with_tools(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
tools: list,
|
||||
user_id: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""辅助方法:带工具的流式生成(无tool_choice,AI自由决定)"""
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
):
|
||||
if chunk.get("tool_calls"):
|
||||
from app.mcp import mcp_client
|
||||
actual_user_id = user_id or ""
|
||||
tool_results = await mcp_client.batch_call_tools(
|
||||
user_id=actual_user_id,
|
||||
tool_calls=chunk["tool_calls"]
|
||||
)
|
||||
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
|
||||
|
||||
# 再次调用获取最终回答
|
||||
messages.append({"role": "user", "content": f"{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"})
|
||||
|
||||
async for final_chunk in self._generate_with_tools(
|
||||
messages, model, temperature, max_tokens, tools, user_id
|
||||
):
|
||||
yield final_chunk
|
||||
break
|
||||
|
||||
if chunk.get("done"):
|
||||
break
|
||||
|
||||
if chunk.get("content"):
|
||||
yield chunk["content"]
|
||||
+475
-394
@@ -1,17 +1,68 @@
|
||||
"""AI服务封装 - 统一的OpenAI和Claude接口"""
|
||||
from typing import Optional, AsyncGenerator, List, Dict, Any
|
||||
from openai import AsyncOpenAI
|
||||
from anthropic import AsyncAnthropic
|
||||
"""AI服务封装 - 统一的AI接口
|
||||
|
||||
重构后支持自动MCP工具加载:
|
||||
- 所有AI方法在请求前自动检查用户MCP配置
|
||||
- 如果有启用的MCP插件且有可用工具,自动发送tools
|
||||
- 通过 auto_mcp 参数控制是否启用自动工具加载
|
||||
"""
|
||||
from typing import Optional, AsyncGenerator, List, Dict, Any, Union
|
||||
|
||||
from app.config import settings as app_settings
|
||||
from app.logger import get_logger
|
||||
import httpx
|
||||
from app.services.ai_config import AIClientConfig, default_config
|
||||
from app.services.ai_clients.openai_client import OpenAIClient
|
||||
from app.services.ai_clients.anthropic_client import AnthropicClient
|
||||
from app.services.ai_clients.gemini_client import GeminiClient
|
||||
from app.services.ai_clients.base_client import cleanup_all_clients
|
||||
from app.services.ai_providers.openai_provider import OpenAIProvider
|
||||
from app.services.ai_providers.anthropic_provider import AnthropicProvider
|
||||
from app.services.ai_providers.gemini_provider import GeminiProvider
|
||||
from app.services.ai_providers.base_provider import BaseAIProvider
|
||||
from app.services.json_helper import clean_json_response, parse_json
|
||||
|
||||
# 导出清理函数
|
||||
cleanup_http_clients = cleanup_all_clients
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AIService:
|
||||
"""AI服务统一接口 - 支持从用户设置或全局配置初始化"""
|
||||
"""
|
||||
AI服务统一接口
|
||||
|
||||
MCP工具支持:
|
||||
- 在创建服务时传入 user_id 和 db_session
|
||||
- 根据用户MCP插件的enabled状态自动决定是否启用MCP
|
||||
- 如果有任意一个MCP插件启用,则加载并使用工具
|
||||
- 如果所有插件都关闭,则不使用任何MCP工具
|
||||
- 通过 auto_mcp=False 可临时禁用自动工具加载
|
||||
- 通过 mcp_max_rounds 控制工具调用轮数
|
||||
- 通过 clear_mcp_cache() 可清理MCP工具缓存
|
||||
|
||||
MCP启用逻辑(backend/app/api/settings.py 中的 get_user_ai_service):
|
||||
- 查询用户的所有MCP插件
|
||||
- 如果有启用的插件 (enabled=True),则 enable_mcp=True
|
||||
- 如果所有插件都关闭或没有插件,则 enable_mcp=False
|
||||
|
||||
使用示例:
|
||||
# 创建支持MCP的AI服务(根据插件状态自动决定是否启用)
|
||||
ai_service = create_user_ai_service_with_mcp(
|
||||
api_provider="openai",
|
||||
api_key="...",
|
||||
user_id="user123",
|
||||
db_session=db
|
||||
)
|
||||
|
||||
# 自动加载MCP工具(如果有启用的插件)
|
||||
result = await ai_service.generate_text(prompt="...")
|
||||
|
||||
# 临时禁用MCP工具
|
||||
result = await ai_service.generate_text(prompt="...", auto_mcp=False)
|
||||
|
||||
# 自定义轮数
|
||||
result = await ai_service.generate_text(prompt="...", mcp_max_rounds=3)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_provider: Optional[str] = None,
|
||||
@@ -19,106 +70,252 @@ class AIService:
|
||||
api_base_url: Optional[str] = None,
|
||||
default_model: Optional[str] = None,
|
||||
default_temperature: Optional[float] = None,
|
||||
default_max_tokens: Optional[int] = None
|
||||
default_max_tokens: Optional[int] = None,
|
||||
default_system_prompt: Optional[str] = None,
|
||||
config: Optional[AIClientConfig] = None,
|
||||
# MCP支持参数
|
||||
user_id: Optional[str] = None,
|
||||
db_session: Optional[Any] = None,
|
||||
enable_mcp: bool = True,
|
||||
):
|
||||
"""
|
||||
初始化AI客户端(优化并发性能)
|
||||
|
||||
Args:
|
||||
api_provider: API提供商 (openai/anthropic),为None时使用全局配置
|
||||
api_key: API密钥,为None时使用全局配置
|
||||
api_base_url: API基础URL,为None时使用全局配置
|
||||
default_model: 默认模型,为None时使用全局配置
|
||||
default_temperature: 默认温度,为None时使用全局配置
|
||||
default_max_tokens: 默认最大tokens,为None时使用全局配置
|
||||
"""
|
||||
# 保存用户设置或使用全局配置
|
||||
self.api_provider = api_provider or app_settings.default_ai_provider
|
||||
self.default_model = default_model or app_settings.default_model
|
||||
self.default_temperature = default_temperature or app_settings.default_temperature
|
||||
self.default_max_tokens = default_max_tokens or app_settings.default_max_tokens
|
||||
self.default_system_prompt = default_system_prompt
|
||||
self.config = config or default_config
|
||||
|
||||
# 初始化OpenAI客户端
|
||||
# MCP配置
|
||||
self.user_id = user_id
|
||||
self.db_session = db_session
|
||||
self._enable_mcp = enable_mcp
|
||||
self._cached_tools: Optional[List[Dict]] = None
|
||||
self._tools_loaded = False
|
||||
|
||||
self._openai_provider: Optional[OpenAIProvider] = None
|
||||
self._anthropic_provider: Optional[AnthropicProvider] = None
|
||||
self._gemini_provider: Optional[GeminiProvider] = None
|
||||
|
||||
# 初始化 OpenAI
|
||||
openai_key = api_key if api_provider == "openai" else app_settings.openai_api_key
|
||||
if openai_key:
|
||||
try:
|
||||
limits = httpx.Limits(
|
||||
max_keepalive_connections=50,
|
||||
max_connections=100,
|
||||
keepalive_expiry=30.0
|
||||
)
|
||||
|
||||
http_client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(connect=60.0, read=180.0, write=60.0, pool=60.0),
|
||||
limits=limits,
|
||||
headers={
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
|
||||
}
|
||||
)
|
||||
|
||||
client_kwargs = {
|
||||
"api_key": openai_key,
|
||||
"http_client": http_client
|
||||
}
|
||||
|
||||
base_url = api_base_url if api_provider == "openai" else app_settings.openai_base_url
|
||||
if base_url:
|
||||
client_kwargs["base_url"] = base_url
|
||||
|
||||
self.openai_client = AsyncOpenAI(**client_kwargs)
|
||||
self.openai_http_client = http_client
|
||||
self.openai_api_key = openai_key
|
||||
self.openai_base_url = base_url
|
||||
logger.info("✅ OpenAI客户端初始化成功")
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI客户端初始化失败: {e}")
|
||||
self.openai_client = None
|
||||
self.openai_http_client = None
|
||||
self.openai_api_key = None
|
||||
self.openai_base_url = None
|
||||
else:
|
||||
self.openai_client = None
|
||||
self.openai_http_client = None
|
||||
self.openai_api_key = None
|
||||
self.openai_base_url = None
|
||||
logger.warning("OpenAI API key未配置")
|
||||
base_url = api_base_url if api_provider == "openai" else app_settings.openai_base_url
|
||||
client = OpenAIClient(openai_key, base_url or "https://api.openai.com/v1", self.config)
|
||||
self._openai_provider = OpenAIProvider(client)
|
||||
|
||||
# 初始化Anthropic客户端
|
||||
# 初始化 Anthropic
|
||||
anthropic_key = api_key if api_provider == "anthropic" else app_settings.anthropic_api_key
|
||||
if anthropic_key:
|
||||
try:
|
||||
limits = httpx.Limits(
|
||||
max_keepalive_connections=50,
|
||||
max_connections=100,
|
||||
keepalive_expiry=30.0
|
||||
)
|
||||
|
||||
http_client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(connect=60.0, read=180.0, write=60.0, pool=60.0),
|
||||
limits=limits,
|
||||
headers={
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
|
||||
}
|
||||
)
|
||||
|
||||
client_kwargs = {
|
||||
"api_key": anthropic_key,
|
||||
"http_client": http_client
|
||||
}
|
||||
|
||||
base_url = api_base_url if api_provider == "anthropic" else app_settings.anthropic_base_url
|
||||
if base_url:
|
||||
client_kwargs["base_url"] = base_url
|
||||
|
||||
self.anthropic_client = AsyncAnthropic(**client_kwargs)
|
||||
logger.info("✅ Anthropic客户端初始化成功")
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic客户端初始化失败: {e}")
|
||||
self.anthropic_client = None
|
||||
else:
|
||||
self.anthropic_client = None
|
||||
logger.warning("Anthropic API key未配置")
|
||||
base_url = api_base_url if api_provider == "anthropic" else app_settings.anthropic_base_url
|
||||
client = AnthropicClient(anthropic_key, base_url, self.config)
|
||||
self._anthropic_provider = AnthropicProvider(client)
|
||||
|
||||
# 初始化 Gemini
|
||||
if api_provider == "gemini" and api_key:
|
||||
client = GeminiClient(api_key, api_base_url, self.config)
|
||||
self._gemini_provider = GeminiProvider(client)
|
||||
|
||||
@property
|
||||
def enable_mcp(self) -> bool:
|
||||
"""是否启用MCP工具"""
|
||||
return self._enable_mcp
|
||||
|
||||
@enable_mcp.setter
|
||||
def enable_mcp(self, value: bool):
|
||||
"""设置MCP启用状态,如果禁用则清理缓存"""
|
||||
if value is False and self._enable_mcp is True:
|
||||
# 从启用变为禁用,清理缓存
|
||||
self.clear_mcp_cache()
|
||||
self._enable_mcp = value
|
||||
|
||||
def clear_mcp_cache(self):
|
||||
"""
|
||||
清理MCP工具缓存
|
||||
|
||||
当禁用MCP时调用此方法,确保后续AI调用不会使用缓存的工具。
|
||||
同时更新 _tools_loaded 状态,使下次调用时重新检查。
|
||||
"""
|
||||
if self._cached_tools is not None:
|
||||
logger.info(f"🔧 清理MCP工具缓存,移除 {len(self._cached_tools)} 个工具")
|
||||
self._cached_tools = None
|
||||
else:
|
||||
logger.debug(f"🔧 MCP工具缓存已经是空,无需清理")
|
||||
|
||||
# 更新加载状态,确保下次调用会重新检查
|
||||
self._tools_loaded = False
|
||||
logger.debug(f"🔧 MCP工具状态已重置: enable_mcp={self._enable_mcp}, _tools_loaded=False")
|
||||
|
||||
def _get_provider(self, provider: Optional[str] = None) -> BaseAIProvider:
|
||||
"""获取对应的 Provider"""
|
||||
p = provider or self.api_provider
|
||||
if p == "openai" and self._openai_provider:
|
||||
return self._openai_provider
|
||||
if p == "anthropic" and self._anthropic_provider:
|
||||
return self._anthropic_provider
|
||||
if p == "gemini" and self._gemini_provider:
|
||||
return self._gemini_provider
|
||||
raise ValueError(f"Provider {p} 未初始化")
|
||||
|
||||
async def _prepare_mcp_tools(self, auto_mcp: bool = True, force_refresh: bool = False) -> Optional[List[Dict]]:
|
||||
"""
|
||||
预处理MCP工具
|
||||
|
||||
检查用户MCP配置并加载可用工具。
|
||||
结果会被缓存,避免重复加载。
|
||||
|
||||
Args:
|
||||
auto_mcp: 是否自动加载MCP工具(来自调用方参数)
|
||||
force_refresh: 是否强制刷新缓存
|
||||
|
||||
Returns:
|
||||
- None: 无可用工具(未配置/未启用/加载失败)
|
||||
- List[Dict]: OpenAI格式的工具列表
|
||||
"""
|
||||
# 前置条件检查
|
||||
if not self._enable_mcp:
|
||||
logger.debug(f"🔧 MCP工具未启用 (_enable_mcp=False)")
|
||||
# 即使有缓存也清理掉,确保不使用
|
||||
self._cached_tools = None
|
||||
self._tools_loaded = False
|
||||
return None
|
||||
|
||||
if not auto_mcp:
|
||||
logger.debug(f"🔧 auto_mcp=False,跳过MCP工具加载")
|
||||
# 即使有缓存也清理掉,确保不使用
|
||||
self._cached_tools = None
|
||||
self._tools_loaded = False
|
||||
return None
|
||||
|
||||
if not self.user_id:
|
||||
logger.debug(f"🔧 MCP工具加载跳过: user_id未设置")
|
||||
return None
|
||||
|
||||
if not self.db_session:
|
||||
logger.debug(f"🔧 MCP工具加载跳过: db_session未设置")
|
||||
return None
|
||||
|
||||
# 使用缓存(只有 enable_mcp=True 时才使用缓存)
|
||||
if self._tools_loaded and not force_refresh:
|
||||
if self._cached_tools:
|
||||
logger.debug(f"🔧 使用缓存的MCP工具 ({len(self._cached_tools)}个)")
|
||||
return self._cached_tools
|
||||
|
||||
try:
|
||||
from app.services.mcp_tools_loader import mcp_tools_loader
|
||||
|
||||
self._cached_tools = await mcp_tools_loader.get_user_tools(
|
||||
user_id=self.user_id,
|
||||
db_session=self.db_session,
|
||||
use_cache=True,
|
||||
force_refresh=force_refresh
|
||||
)
|
||||
self._tools_loaded = True
|
||||
|
||||
if self._cached_tools:
|
||||
logger.info(f"🔧 已加载 {len(self._cached_tools)} 个MCP工具")
|
||||
else:
|
||||
logger.debug(f"📭 用户 {self.user_id} 没有可用的MCP工具")
|
||||
|
||||
return self._cached_tools
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 加载MCP工具失败: {e}")
|
||||
self._tools_loaded = True
|
||||
self._cached_tools = None
|
||||
return None
|
||||
|
||||
async def _handle_tool_calls(
|
||||
self,
|
||||
original_prompt: str,
|
||||
response: Dict[str, Any],
|
||||
max_rounds: int = 2,
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
处理AI返回的工具调用
|
||||
|
||||
Args:
|
||||
original_prompt: 原始提示词
|
||||
response: AI响应(包含tool_calls)
|
||||
max_rounds: 最大工具调用轮数
|
||||
**kwargs: 传递给generate_text的其他参数
|
||||
|
||||
Returns:
|
||||
最终的AI响应
|
||||
"""
|
||||
from app.mcp import mcp_client
|
||||
|
||||
tool_calls = response.get("tool_calls", [])
|
||||
if not tool_calls or not self.user_id:
|
||||
return response
|
||||
|
||||
result = {
|
||||
"content": response.get("content", ""),
|
||||
"tool_calls_made": 0,
|
||||
"tools_used": [],
|
||||
"finish_reason": response.get("finish_reason", ""),
|
||||
"mcp_enhanced": True
|
||||
}
|
||||
|
||||
prompt = original_prompt
|
||||
|
||||
for round_num in range(max_rounds):
|
||||
logger.info(f"🔧 工具调用 - 第{round_num+1}/{max_rounds}轮,{len(tool_calls)}个工具")
|
||||
|
||||
try:
|
||||
# 批量执行工具调用
|
||||
tool_results = await mcp_client.batch_call_tools(
|
||||
user_id=self.user_id,
|
||||
tool_calls=tool_calls
|
||||
)
|
||||
|
||||
# 记录使用的工具
|
||||
for tc in tool_calls:
|
||||
name = tc["function"]["name"]
|
||||
if name not in result["tools_used"]:
|
||||
result["tools_used"].append(name)
|
||||
result["tool_calls_made"] += len(tool_calls)
|
||||
|
||||
# 构建工具上下文
|
||||
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
|
||||
|
||||
# 更新提示词
|
||||
if round_num == max_rounds - 1:
|
||||
# 最后一轮,强制要求回答
|
||||
prompt = f"{original_prompt}\n\n{tool_context}\n\n⚠️ 重要:请基于以上工具查询结果,给出完整详细的最终答案。不要再调用工具。"
|
||||
tool_choice = "none"
|
||||
else:
|
||||
prompt = f"{original_prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,继续完成任务。"
|
||||
tool_choice = kwargs.get("tool_choice", "auto")
|
||||
|
||||
# 继续调用AI
|
||||
prov = self._get_provider(kwargs.get("provider"))
|
||||
next_response = await prov.generate(
|
||||
prompt=prompt,
|
||||
model=kwargs.get("model") or self.default_model,
|
||||
temperature=kwargs.get("temperature") or self.default_temperature,
|
||||
max_tokens=kwargs.get("max_tokens") or self.default_max_tokens,
|
||||
system_prompt=kwargs.get("system_prompt") or self.default_system_prompt,
|
||||
tools=None if tool_choice == "none" else self._cached_tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
|
||||
tool_calls = next_response.get("tool_calls", [])
|
||||
|
||||
if not tool_calls:
|
||||
# 没有更多工具调用,返回结果
|
||||
result["content"] = next_response.get("content", "")
|
||||
result["finish_reason"] = next_response.get("finish_reason", "stop")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 工具调用失败: {e}")
|
||||
result["content"] = response.get("content", "")
|
||||
result["finish_reason"] = "tool_error"
|
||||
break
|
||||
|
||||
return result
|
||||
|
||||
async def generate_text(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -126,38 +323,67 @@ class AIService:
|
||||
model: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
system_prompt: Optional[str] = None
|
||||
) -> str:
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
auto_mcp: bool = True,
|
||||
handle_tool_calls: bool = True,
|
||||
mcp_max_rounds: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
生成文本
|
||||
生成文本(自动支持MCP工具)
|
||||
|
||||
Args:
|
||||
prompt: 用户提示词
|
||||
provider: AI提供商 (openai/anthropic)
|
||||
provider: AI提供商
|
||||
model: 模型名称
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
temperature: 温度
|
||||
max_tokens: 最大令牌数
|
||||
system_prompt: 系统提示词
|
||||
tools: 手动指定的工具列表(优先级高于自动加载)
|
||||
tool_choice: 工具选择策略
|
||||
auto_mcp: 是否自动加载MCP工具(默认True)
|
||||
handle_tool_calls: 是否自动处理工具调用(默认True)
|
||||
mcp_max_rounds: 最大工具调用轮数(None使用默认值3)
|
||||
|
||||
Returns:
|
||||
生成的文本
|
||||
包含生成内容的字典
|
||||
"""
|
||||
provider = provider or self.api_provider
|
||||
model = model or self.default_model
|
||||
temperature = temperature or self.default_temperature
|
||||
max_tokens = max_tokens or self.default_max_tokens
|
||||
# 使用全局配置的MCP轮数(如果未指定)
|
||||
if mcp_max_rounds is None:
|
||||
mcp_max_rounds = app_settings.mcp_max_rounds
|
||||
|
||||
if provider == "openai":
|
||||
return await self._generate_openai(
|
||||
prompt, model, temperature, max_tokens, system_prompt
|
||||
# 自动加载MCP工具
|
||||
if auto_mcp and tools is None:
|
||||
tools = await self._prepare_mcp_tools(auto_mcp=auto_mcp)
|
||||
|
||||
prov = self._get_provider(provider)
|
||||
response = await prov.generate(
|
||||
prompt=prompt,
|
||||
model=model or self.default_model,
|
||||
temperature=temperature or self.default_temperature,
|
||||
max_tokens=max_tokens or self.default_max_tokens,
|
||||
system_prompt=system_prompt or self.default_system_prompt,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
|
||||
# 处理工具调用
|
||||
if handle_tool_calls and response.get("tool_calls"):
|
||||
return await self._handle_tool_calls(
|
||||
original_prompt=prompt,
|
||||
response=response,
|
||||
provider=provider,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
tool_choice=tool_choice,
|
||||
max_rounds=mcp_max_rounds,
|
||||
)
|
||||
elif provider == "anthropic":
|
||||
return await self._generate_anthropic(
|
||||
prompt, model, temperature, max_tokens, system_prompt
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"不支持的AI提供商: {provider}")
|
||||
|
||||
|
||||
return response
|
||||
|
||||
async def generate_text_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -165,301 +391,123 @@ class AIService:
|
||||
model: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
system_prompt: Optional[str] = None
|
||||
system_prompt: Optional[str] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
auto_mcp: bool = True,
|
||||
mcp_max_rounds: Optional[int] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
流式生成文本
|
||||
流式生成文本(自动支持MCP工具)
|
||||
|
||||
工具调用在 Provider 层通过流式方式处理,支持真正的流式工具调用。
|
||||
|
||||
Args:
|
||||
prompt: 用户提示词
|
||||
provider: AI提供商
|
||||
model: 模型名称
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
temperature: 温度
|
||||
max_tokens: 最大令牌数
|
||||
system_prompt: 系统提示词
|
||||
tool_choice: 工具选择策略("auto"/"none"/"required")
|
||||
auto_mcp: 是否自动加载MCP工具
|
||||
mcp_max_rounds: 最大工具调用轮数(None使用默认值3)
|
||||
|
||||
Yields:
|
||||
生成的文本片段
|
||||
生成的文本块
|
||||
"""
|
||||
provider = provider or self.api_provider
|
||||
model = model or self.default_model
|
||||
temperature = temperature or self.default_temperature
|
||||
max_tokens = max_tokens or self.default_max_tokens
|
||||
logger.debug(f"🔧 generate_text_stream: auto_mcp={auto_mcp}, tool_choice={tool_choice}")
|
||||
|
||||
if provider == "openai":
|
||||
async for chunk in self._generate_openai_stream(
|
||||
prompt, model, temperature, max_tokens, system_prompt
|
||||
):
|
||||
yield chunk
|
||||
elif provider == "anthropic":
|
||||
async for chunk in self._generate_anthropic_stream(
|
||||
prompt, model, temperature, max_tokens, system_prompt
|
||||
):
|
||||
yield chunk
|
||||
else:
|
||||
raise ValueError(f"不支持的AI提供商: {provider}")
|
||||
|
||||
async def _generate_openai(
|
||||
tools_to_use = None
|
||||
|
||||
# 加载MCP工具
|
||||
if auto_mcp:
|
||||
tools_to_use = await self._prepare_mcp_tools(auto_mcp=auto_mcp)
|
||||
if tools_to_use:
|
||||
logger.info(f"🔧 已获取 {len(tools_to_use)} 个MCP工具")
|
||||
|
||||
# 流式生成(Provider 层处理工具调用)
|
||||
prov = self._get_provider(provider)
|
||||
logger.debug(f"🔧 开始流式生成,provider={provider or self.api_provider}, tools_count={len(tools_to_use) if tools_to_use else 0}")
|
||||
async for chunk in prov.generate_stream(
|
||||
prompt=prompt,
|
||||
model=model or self.default_model,
|
||||
temperature=temperature or self.default_temperature,
|
||||
max_tokens=max_tokens or self.default_max_tokens,
|
||||
system_prompt=system_prompt or self.default_system_prompt,
|
||||
tools=tools_to_use,
|
||||
tool_choice=tool_choice,
|
||||
user_id=self.user_id,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
async def call_with_json_retry(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str]
|
||||
) -> str:
|
||||
"""使用OpenAI生成文本"""
|
||||
if not self.openai_http_client:
|
||||
raise ValueError("OpenAI客户端未初始化,请检查API key配置")
|
||||
system_prompt: Optional[str] = None,
|
||||
max_retries: int = 3,
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
provider: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
expected_type: Optional[str] = None,
|
||||
auto_mcp: bool = True,
|
||||
) -> Union[Dict, List]:
|
||||
"""
|
||||
带重试的 JSON 调用(自动支持MCP工具)
|
||||
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
Args:
|
||||
prompt: 用户提示词
|
||||
system_prompt: 系统提示词
|
||||
max_retries: 最大重试次数
|
||||
temperature: 温度
|
||||
max_tokens: 最大令牌数
|
||||
provider: AI提供商
|
||||
model: 模型名称
|
||||
expected_type: 期望的返回类型("object"或"array")
|
||||
auto_mcp: 是否自动加载MCP工具
|
||||
|
||||
Returns:
|
||||
解析后的JSON数据
|
||||
"""
|
||||
last_response = ""
|
||||
|
||||
try:
|
||||
logger.info(f"🔵 开始调用OpenAI API(直接HTTP请求)")
|
||||
logger.info(f" - 模型: {model}")
|
||||
logger.info(f" - 温度: {temperature}")
|
||||
logger.info(f" - 最大tokens: {max_tokens}")
|
||||
logger.info(f" - Prompt长度: {len(prompt)} 字符")
|
||||
logger.info(f" - 消息数量: {len(messages)}")
|
||||
for attempt in range(1, max_retries + 1):
|
||||
current_prompt = prompt if attempt == 1 else self._add_json_hint(prompt, last_response, attempt)
|
||||
|
||||
url = f"{self.openai_base_url}/chat/completions"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.openai_api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens
|
||||
}
|
||||
|
||||
logger.debug(f" - 请求URL: {url}")
|
||||
logger.debug(f" - 请求头: Authorization=Bearer ***")
|
||||
|
||||
response = await self.openai_http_client.post(url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
|
||||
logger.info(f"✅ OpenAI API调用成功")
|
||||
logger.info(f" - 响应ID: {data.get('id', 'N/A')}")
|
||||
logger.info(f" - 选项数量: {len(data.get('choices', []))}")
|
||||
|
||||
if not data.get('choices'):
|
||||
logger.error("❌ OpenAI返回的choices为空")
|
||||
raise ValueError("API返回的响应格式错误:choices字段为空")
|
||||
|
||||
choice = data['choices'][0]
|
||||
message = choice.get('message', {})
|
||||
finish_reason = choice.get('finish_reason')
|
||||
|
||||
# DeepSeek R1特殊处理:只使用content(最终答案),忽略reasoning_content(思考过程)
|
||||
# reasoning_content是AI的思考过程,不是我们需要的JSON结果
|
||||
content = message.get('content', '')
|
||||
|
||||
# 检查是否因达到长度限制而截断
|
||||
if finish_reason == 'length':
|
||||
logger.warning(f"⚠️ 响应因达到max_tokens限制而被截断")
|
||||
logger.warning(f" - 当前max_tokens: {max_tokens}")
|
||||
logger.warning(f" - 建议: 增加max_tokens参数(推荐2000+)")
|
||||
|
||||
if content:
|
||||
logger.info(f" - 返回内容长度: {len(content)} 字符")
|
||||
logger.info(f" - 完成原因: {finish_reason}")
|
||||
logger.info(f" - 返回内容预览(前200字符): {content[:200]}")
|
||||
return content
|
||||
else:
|
||||
logger.error("❌ AI返回了空内容")
|
||||
logger.error(f" - 完整响应: {data}")
|
||||
logger.error(f" - 完成原因: {finish_reason}")
|
||||
|
||||
# 提供更详细的错误信息
|
||||
if finish_reason == 'length':
|
||||
raise ValueError(f"AI响应被截断且无有效内容。请增加max_tokens参数(当前: {max_tokens},建议: 2000+)")
|
||||
else:
|
||||
raise ValueError(f"AI返回了空内容(finish_reason: {finish_reason}),请检查API配置或稍后重试")
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"❌ OpenAI API调用失败 (HTTP {e.response.status_code})")
|
||||
logger.error(f" - 错误信息: {e.response.text}")
|
||||
logger.error(f" - 模型: {model}")
|
||||
raise Exception(f"API返回错误 ({e.response.status_code}): {e.response.text}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ OpenAI API调用失败")
|
||||
logger.error(f" - 错误类型: {type(e).__name__}")
|
||||
logger.error(f" - 错误信息: {str(e)}")
|
||||
logger.error(f" - 模型: {model}")
|
||||
raise
|
||||
|
||||
async def _generate_openai_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str]
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""使用OpenAI流式生成文本"""
|
||||
if not self.openai_http_client:
|
||||
raise ValueError("OpenAI客户端未初始化,请检查API key配置")
|
||||
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
try:
|
||||
logger.info(f"🔵 开始调用OpenAI流式API(直接HTTP请求)")
|
||||
logger.info(f" - 模型: {model}")
|
||||
logger.info(f" - Prompt长度: {len(prompt)} 字符")
|
||||
logger.info(f" - 最大tokens: {max_tokens}")
|
||||
|
||||
url = f"{self.openai_base_url}/chat/completions"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.openai_api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": True
|
||||
}
|
||||
|
||||
async with self.openai_http_client.stream('POST', url, headers=headers, json=payload) as response:
|
||||
response.raise_for_status()
|
||||
logger.info(f"✅ OpenAI流式API连接成功,开始接收数据...")
|
||||
|
||||
chunk_count = 0
|
||||
has_content = False
|
||||
finish_reason = None
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith('data: '):
|
||||
data_str = line[6:]
|
||||
if data_str.strip() == '[DONE]':
|
||||
break
|
||||
|
||||
try:
|
||||
import json
|
||||
data = json.loads(data_str)
|
||||
if 'choices' in data and len(data['choices']) > 0:
|
||||
choice = data['choices'][0]
|
||||
delta = choice.get('delta', {})
|
||||
finish_reason = choice.get('finish_reason') or finish_reason
|
||||
|
||||
# DeepSeek R1特殊处理:只收集content(最终答案),忽略reasoning_content(思考过程)
|
||||
# reasoning_content是AI的思考过程,不是我们需要的JSON结果
|
||||
content = delta.get('content', '')
|
||||
|
||||
if content:
|
||||
chunk_count += 1
|
||||
has_content = True
|
||||
yield content
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
# 检查是否因长度限制截断
|
||||
if finish_reason == 'length':
|
||||
logger.warning(f"⚠️ 流式响应因达到max_tokens限制而被截断")
|
||||
logger.warning(f" - 当前max_tokens: {max_tokens}")
|
||||
logger.warning(f" - 建议: 增加max_tokens参数(推荐2000+)")
|
||||
|
||||
if not has_content:
|
||||
logger.warning(f"⚠️ 流式响应未返回任何内容")
|
||||
logger.warning(f" - 完成原因: {finish_reason}")
|
||||
|
||||
logger.info(f"✅ OpenAI流式生成完成,共接收 {chunk_count} 个chunk,完成原因: {finish_reason}")
|
||||
|
||||
except httpx.TimeoutException as e:
|
||||
logger.error(f"❌ OpenAI流式API超时")
|
||||
logger.error(f" - 错误: {str(e)}")
|
||||
logger.error(f" - 提示: 请检查网络连接或考虑缩短prompt长度")
|
||||
raise TimeoutError(f"AI服务超时(180秒),请稍后重试或减少上下文长度") from e
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"❌ OpenAI流式API调用失败 (HTTP {e.response.status_code})")
|
||||
logger.error(f" - 错误信息: {await e.response.aread()}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"❌ OpenAI流式API调用失败: {str(e)}")
|
||||
logger.error(f" - 错误类型: {type(e).__name__}")
|
||||
raise
|
||||
|
||||
async def _generate_anthropic(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str]
|
||||
) -> str:
|
||||
"""使用Anthropic生成文本"""
|
||||
if not self.anthropic_client:
|
||||
raise ValueError("Anthropic客户端未初始化,请检查API key配置")
|
||||
|
||||
try:
|
||||
response = await self.anthropic_client.messages.create(
|
||||
result = await self.generate_text(
|
||||
prompt=current_prompt,
|
||||
provider=provider,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
system=system_prompt or "",
|
||||
messages=[{"role": "user", "content": prompt}]
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
auto_mcp=auto_mcp,
|
||||
handle_tool_calls=True,
|
||||
)
|
||||
return response.content[0].text
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic API调用失败: {str(e)}")
|
||||
raise
|
||||
|
||||
async def _generate_anthropic_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str]
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""使用Anthropic流式生成文本"""
|
||||
if not self.anthropic_client:
|
||||
raise ValueError("Anthropic客户端未初始化,请检查API key配置")
|
||||
|
||||
try:
|
||||
logger.info(f"🔵 开始调用Anthropic流式API")
|
||||
logger.info(f" - 模型: {model}")
|
||||
logger.info(f" - Prompt长度: {len(prompt)} 字符")
|
||||
logger.info(f" - 最大tokens: {max_tokens}")
|
||||
|
||||
async with self.anthropic_client.messages.stream(
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
system=system_prompt or "",
|
||||
messages=[{"role": "user", "content": prompt}]
|
||||
) as stream:
|
||||
logger.info(f"✅ Anthropic流式API连接成功,开始接收数据...")
|
||||
|
||||
chunk_count = 0
|
||||
async for text in stream.text_stream:
|
||||
chunk_count += 1
|
||||
yield text
|
||||
|
||||
logger.info(f"✅ Anthropic流式生成完成,共接收 {chunk_count} 个chunk")
|
||||
|
||||
except httpx.TimeoutException as e:
|
||||
logger.error(f"❌ Anthropic流式API超时")
|
||||
logger.error(f" - 错误: {str(e)}")
|
||||
raise TimeoutError(f"AI服务超时(180秒),请稍后重试或减少上下文长度") from e
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Anthropic流式API调用失败: {str(e)}")
|
||||
logger.error(f" - 错误类型: {type(e).__name__}")
|
||||
raise
|
||||
last_response = result.get("content", "")
|
||||
|
||||
try:
|
||||
data = parse_json(last_response)
|
||||
if expected_type == "object" and not isinstance(data, dict):
|
||||
raise ValueError("期望对象")
|
||||
if expected_type == "array" and not isinstance(data, list):
|
||||
raise ValueError("期望数组")
|
||||
return data
|
||||
except Exception as e:
|
||||
if attempt == max_retries:
|
||||
raise ValueError(f"JSON 解析失败: {e}")
|
||||
|
||||
raise ValueError("JSON 调用失败")
|
||||
|
||||
@staticmethod
|
||||
def _add_json_hint(prompt: str, failed: str, attempt: int) -> str:
|
||||
return f"{prompt}\n\n⚠️ 第{attempt}次重试,请返回纯JSON,不要markdown包裹。上次错误: {failed[:200]}..."
|
||||
|
||||
# 创建全局AI服务实例
|
||||
ai_service = AIService()
|
||||
@staticmethod
|
||||
def _clean_json_response(text: str) -> str:
|
||||
"""清洗 JSON 响应"""
|
||||
return clean_json_response(text)
|
||||
|
||||
|
||||
def create_user_ai_service(
|
||||
@@ -468,21 +516,50 @@ def create_user_ai_service(
|
||||
api_base_url: str,
|
||||
model_name: str,
|
||||
temperature: float,
|
||||
max_tokens: int
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
) -> AIService:
|
||||
"""创建用户 AI 服务(不带MCP支持)"""
|
||||
return AIService(
|
||||
api_provider=api_provider,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
default_model=model_name,
|
||||
default_temperature=temperature,
|
||||
default_max_tokens=max_tokens,
|
||||
default_system_prompt=system_prompt,
|
||||
)
|
||||
|
||||
|
||||
def create_user_ai_service_with_mcp(
|
||||
api_provider: str,
|
||||
api_key: str,
|
||||
api_base_url: str,
|
||||
model_name: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
user_id: str,
|
||||
db_session,
|
||||
system_prompt: Optional[str] = None,
|
||||
enable_mcp: bool = True,
|
||||
) -> AIService:
|
||||
"""
|
||||
根据用户设置创建AI服务实例
|
||||
创建支持MCP的用户AI服务
|
||||
|
||||
Args:
|
||||
api_provider: API提供商
|
||||
api_provider: AI提供商
|
||||
api_key: API密钥
|
||||
api_base_url: API基础URL
|
||||
model_name: 模型名称
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大tokens
|
||||
temperature: 温度
|
||||
max_tokens: 最大令牌数
|
||||
user_id: 用户ID(用于加载MCP工具)
|
||||
db_session: 数据库会话
|
||||
system_prompt: 系统提示词
|
||||
enable_mcp: 是否启用MCP工具
|
||||
|
||||
Returns:
|
||||
AIService实例
|
||||
配置好的AIService实例
|
||||
"""
|
||||
return AIService(
|
||||
api_provider=api_provider,
|
||||
@@ -490,5 +567,9 @@ def create_user_ai_service(
|
||||
api_base_url=api_base_url,
|
||||
default_model=model_name,
|
||||
default_temperature=temperature,
|
||||
default_max_tokens=max_tokens
|
||||
default_max_tokens=max_tokens,
|
||||
default_system_prompt=system_prompt,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
enable_mcp=enable_mcp,
|
||||
)
|
||||
@@ -0,0 +1,606 @@
|
||||
"""自动角色引入服务 - 在续写大纲时根据剧情推进自动引入新角色"""
|
||||
from typing import List, Dict, Any, Optional, Callable, Awaitable
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
import json
|
||||
|
||||
from app.models.character import Character
|
||||
from app.models.relationship import CharacterRelationship, Organization, OrganizationMember, RelationshipType
|
||||
from app.models.project import Project
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.prompt_service import PromptService
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AutoCharacterService:
|
||||
"""自动角色引入服务"""
|
||||
|
||||
def __init__(self, ai_service: AIService):
|
||||
self.ai_service = ai_service
|
||||
|
||||
async def analyze_and_create_characters(
|
||||
self,
|
||||
project_id: str,
|
||||
outline_content: str,
|
||||
existing_characters: List[Character],
|
||||
db: AsyncSession,
|
||||
user_id: str = None,
|
||||
enable_mcp: bool = True,
|
||||
all_chapters_brief: str = "",
|
||||
start_chapter: int = 1,
|
||||
chapter_count: int = 3,
|
||||
plot_stage: str = "发展",
|
||||
story_direction: str = "继续推进主线剧情",
|
||||
preview_only: bool = False,
|
||||
progress_callback: Optional[Callable[[str], Awaitable[None]]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
预测性分析并创建需要的新角色(方案A:先角色后大纲)
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
outline_content: 当前批次大纲内容(用于向后兼容,实际不使用)
|
||||
existing_characters: 现有角色列表
|
||||
db: 数据库会话
|
||||
user_id: 用户ID(用于MCP和自定义提示词)
|
||||
enable_mcp: 是否启用MCP增强
|
||||
all_chapters_brief: 已有章节概览
|
||||
start_chapter: 起始章节号
|
||||
chapter_count: 续写章节数
|
||||
plot_stage: 剧情阶段
|
||||
story_direction: 故事发展方向
|
||||
preview_only: 仅预测不创建(用于角色确认机制)
|
||||
|
||||
Returns:
|
||||
{
|
||||
"new_characters": [角色对象列表], # preview_only=True时为空
|
||||
"relationships_created": [关系对象列表], # preview_only=True时为空
|
||||
"character_count": 新增角色数量,
|
||||
"analysis_result": AI分析结果,
|
||||
"predicted_characters": [预测的角色数据] # 仅preview_only=True时返回
|
||||
"needs_new_characters": bool,
|
||||
"reason": str
|
||||
}
|
||||
"""
|
||||
logger.info(f"🎭 【方案A】预测性分析:检测是否需要引入新角色...")
|
||||
logger.info(f" - 项目ID: {project_id}")
|
||||
logger.info(f" - 续写计划: 第{start_chapter}章起,共{chapter_count}章")
|
||||
logger.info(f" - 剧情阶段: {plot_stage}")
|
||||
logger.info(f" - 发展方向: {story_direction}")
|
||||
logger.info(f" - 现有角色数: {len(existing_characters)}")
|
||||
|
||||
# 1. 获取项目信息
|
||||
project_result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = project_result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise ValueError("项目不存在")
|
||||
|
||||
# 2. 构建现有角色信息摘要
|
||||
existing_chars_summary = self._build_character_summary(existing_characters)
|
||||
|
||||
# 3. AI预测性分析是否需要新角色
|
||||
analysis_result = await self._analyze_character_needs(
|
||||
project=project,
|
||||
outline_content=outline_content, # 保留参数向后兼容
|
||||
existing_chars_summary=existing_chars_summary,
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
enable_mcp=enable_mcp,
|
||||
all_chapters_brief=all_chapters_brief,
|
||||
start_chapter=start_chapter,
|
||||
chapter_count=chapter_count,
|
||||
plot_stage=plot_stage,
|
||||
story_direction=story_direction
|
||||
)
|
||||
|
||||
# 4. 判断是否需要创建角色
|
||||
if not analysis_result or not analysis_result.get("needs_new_characters"):
|
||||
logger.info("✅ AI判断:当前剧情不需要引入新角色")
|
||||
return {
|
||||
"new_characters": [],
|
||||
"relationships_created": [],
|
||||
"character_count": 0,
|
||||
"analysis_result": analysis_result,
|
||||
"predicted_characters": [],
|
||||
"needs_new_characters": False,
|
||||
"reason": analysis_result.get("reason", "当前剧情不需要新角色")
|
||||
}
|
||||
|
||||
# 5. 如果是预览模式,仅返回预测结果,不创建角色
|
||||
if preview_only:
|
||||
character_specs = analysis_result.get("character_specifications", [])
|
||||
logger.info(f"🔮 预览模式:预测到 {len(character_specs)} 个角色,不创建数据库记录")
|
||||
return {
|
||||
"new_characters": [],
|
||||
"relationships_created": [],
|
||||
"character_count": 0,
|
||||
"analysis_result": analysis_result,
|
||||
"predicted_characters": character_specs,
|
||||
"needs_new_characters": True,
|
||||
"reason": analysis_result.get("reason", "预测需要新角色")
|
||||
}
|
||||
|
||||
# 6. 批量生成新角色(非预览模式)
|
||||
new_characters = []
|
||||
relationships_created = []
|
||||
|
||||
character_specs = analysis_result.get("character_specifications", [])
|
||||
logger.info(f"🎯 AI建议引入 {len(character_specs)} 个新角色")
|
||||
|
||||
for idx, spec in enumerate(character_specs):
|
||||
try:
|
||||
spec_name = spec.get('name', spec.get('role_description', '未命名'))
|
||||
logger.info(f" [{idx+1}/{len(character_specs)}] 生成角色规格: {spec_name}")
|
||||
logger.debug(f" 角色规格内容: {json.dumps(spec, ensure_ascii=False)}")
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback(f"🎨 [{idx+1}/{len(character_specs)}] 生成角色详情: {spec_name}")
|
||||
|
||||
# 生成角色详细信息
|
||||
character_data = await self._generate_character_details(
|
||||
spec=spec,
|
||||
project=project,
|
||||
existing_characters=existing_characters + new_characters, # 包含新创建的
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
enable_mcp=enable_mcp
|
||||
)
|
||||
|
||||
logger.debug(f" AI生成的角色数据: {json.dumps(character_data, ensure_ascii=False)[:200]}")
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback(f"💾 [{idx+1}/{len(character_specs)}] 保存角色: {character_data.get('name', spec_name)}")
|
||||
|
||||
# 创建角色记录
|
||||
character = await self._create_character_record(
|
||||
project_id=project_id,
|
||||
character_data=character_data,
|
||||
db=db
|
||||
)
|
||||
|
||||
new_characters.append(character)
|
||||
logger.info(f" ✅ 创建新角色: {character.name} ({character.role_type}), ID: {character.id}")
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback(f"✅ [{idx+1}/{len(character_specs)}] 角色创建成功: {character.name}")
|
||||
|
||||
# 建立关系(兼容两种字段名)
|
||||
relationships_data = character_data.get("relationships") or character_data.get("relationships_array", [])
|
||||
logger.info(f" 🔍 检查关系数据:")
|
||||
logger.info(f" - relationships字段: {character_data.get('relationships')}")
|
||||
logger.info(f" - relationships_array字段: {character_data.get('relationships_array')}")
|
||||
logger.info(f" - 最终使用的数据: {relationships_data}")
|
||||
logger.info(f" - 关系数量: {len(relationships_data) if relationships_data else 0}")
|
||||
|
||||
if relationships_data:
|
||||
logger.info(f" 🔗 开始创建 {len(relationships_data)} 条关系...")
|
||||
for idx, rel in enumerate(relationships_data):
|
||||
logger.info(f" [{idx+1}] {rel.get('target_character_name')} - {rel.get('relationship_type')}")
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback(f"🔗 [{idx+1}/{len(character_specs)}] 建立 {len(relationships_data)} 个关系")
|
||||
else:
|
||||
logger.warning(f" ⚠️ AI返回的角色数据中没有关系信息!")
|
||||
logger.warning(f" 完整的character_data keys: {list(character_data.keys())}")
|
||||
|
||||
rels = await self._create_relationships(
|
||||
new_character=character,
|
||||
relationship_specs=relationships_data,
|
||||
existing_characters=existing_characters + new_characters,
|
||||
project_id=project_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
relationships_created.extend(rels)
|
||||
logger.info(f" ✅ 实际创建了 {len(rels)} 条关系记录")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f" ❌ 创建角色失败: {e}", exc_info=True)
|
||||
continue
|
||||
|
||||
# 7. 提交事务(注意:这里只flush,让调用方commit)
|
||||
await db.flush()
|
||||
|
||||
logger.info(f"🎉 自动角色引入完成: 新增{len(new_characters)}个角色, {len(relationships_created)}条关系")
|
||||
|
||||
return {
|
||||
"new_characters": new_characters,
|
||||
"relationships_created": relationships_created,
|
||||
"character_count": len(new_characters),
|
||||
"analysis_result": analysis_result
|
||||
}
|
||||
|
||||
def _build_character_summary(self, characters: List[Character]) -> str:
|
||||
"""构建现有角色摘要"""
|
||||
if not characters:
|
||||
return "暂无角色"
|
||||
|
||||
summary = []
|
||||
for char in characters:
|
||||
char_type = "组织" if char.is_organization else "角色"
|
||||
role_desc = char.role_type or "未知"
|
||||
personality = (char.personality or "")[:50]
|
||||
summary.append(f"- {char.name} ({char_type}, {role_desc}): {personality}")
|
||||
|
||||
return "\n".join(summary[:20]) # 最多显示20个
|
||||
|
||||
async def _analyze_character_needs(
|
||||
self,
|
||||
project: Project,
|
||||
outline_content: str,
|
||||
existing_chars_summary: str,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
enable_mcp: bool,
|
||||
all_chapters_brief: str = "",
|
||||
start_chapter: int = 1,
|
||||
chapter_count: int = 3,
|
||||
plot_stage: str = "发展",
|
||||
story_direction: str = "继续推进主线剧情"
|
||||
) -> Dict[str, Any]:
|
||||
"""AI预测性分析是否需要新角色(方案A)"""
|
||||
|
||||
# 构建分析提示词
|
||||
template = await PromptService.get_template(
|
||||
"AUTO_CHARACTER_ANALYSIS",
|
||||
user_id,
|
||||
db
|
||||
)
|
||||
|
||||
# 使用新的预测性分析参数
|
||||
prompt = PromptService.format_prompt(
|
||||
template,
|
||||
title=project.title,
|
||||
theme=project.theme or "未设定",
|
||||
genre=project.genre or "未设定",
|
||||
time_period=project.world_time_period or "未设定",
|
||||
location=project.world_location or "未设定",
|
||||
atmosphere=project.world_atmosphere or "未设定",
|
||||
existing_characters=existing_chars_summary,
|
||||
all_chapters_brief=all_chapters_brief,
|
||||
start_chapter=start_chapter,
|
||||
chapter_count=chapter_count,
|
||||
plot_stage=plot_stage,
|
||||
story_direction=story_direction
|
||||
)
|
||||
|
||||
try:
|
||||
# 使用统一的JSON调用方法(支持自动MCP工具加载)
|
||||
analysis = await self.ai_service.call_with_json_retry(
|
||||
prompt=prompt,
|
||||
max_retries=3,
|
||||
)
|
||||
|
||||
logger.info(f" ✅ AI分析完成: needs_new_characters={analysis.get('needs_new_characters')}")
|
||||
return analysis
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f" ❌ 角色需求分析JSON解析失败: {e}")
|
||||
return {"needs_new_characters": False}
|
||||
except Exception as e:
|
||||
logger.error(f" ❌ 角色需求分析失败: {e}")
|
||||
return {"needs_new_characters": False}
|
||||
|
||||
async def _generate_character_details(
|
||||
self,
|
||||
spec: Dict[str, Any],
|
||||
project: Project,
|
||||
existing_characters: List[Character],
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
enable_mcp: bool
|
||||
) -> Dict[str, Any]:
|
||||
"""生成角色详细信息"""
|
||||
|
||||
# 🎯 获取项目职业列表
|
||||
from app.models.career import Career
|
||||
careers_result = await db.execute(
|
||||
select(Career)
|
||||
.where(Career.project_id == project.id)
|
||||
.order_by(Career.type, Career.name)
|
||||
)
|
||||
careers = careers_result.scalars().all()
|
||||
|
||||
# 构建职业信息摘要(包含最高阶段信息)
|
||||
careers_info = ""
|
||||
if careers:
|
||||
main_careers = [c for c in careers if c.type == 'main']
|
||||
sub_careers = [c for c in careers if c.type == 'sub']
|
||||
|
||||
if main_careers:
|
||||
careers_info += "\n\n可用主职业列表(请在career_info中填写职业名称和阶段):\n"
|
||||
for career in main_careers:
|
||||
careers_info += f"- 名称: {career.name}, 最高阶段: {career.max_stage}阶"
|
||||
if career.description:
|
||||
careers_info += f", 描述: {career.description[:50]}"
|
||||
careers_info += "\n"
|
||||
|
||||
if sub_careers:
|
||||
careers_info += "\n可用副职业列表(请在career_info中填写职业名称和阶段):\n"
|
||||
for career in sub_careers[:5]:
|
||||
careers_info += f"- 名称: {career.name}, 最高阶段: {career.max_stage}阶"
|
||||
if career.description:
|
||||
careers_info += f", 描述: {career.description[:50]}"
|
||||
careers_info += "\n"
|
||||
|
||||
careers_info += "\n⚠️ 重要提示:生成角色时,职业阶段不能超过该职业的最高阶段!\n"
|
||||
|
||||
# 构建角色生成提示词
|
||||
template = await PromptService.get_template(
|
||||
"AUTO_CHARACTER_GENERATION",
|
||||
user_id,
|
||||
db
|
||||
)
|
||||
|
||||
existing_chars_summary = self._build_character_summary(existing_characters)
|
||||
|
||||
prompt = PromptService.format_prompt(
|
||||
template,
|
||||
title=project.title,
|
||||
genre=project.genre or "未设定",
|
||||
theme=project.theme or "未设定",
|
||||
time_period=project.world_time_period or "未设定",
|
||||
location=project.world_location or "未设定",
|
||||
atmosphere=project.world_atmosphere or "未设定",
|
||||
rules=project.world_rules or "未设定",
|
||||
existing_characters=existing_chars_summary + careers_info,
|
||||
plot_context="根据剧情需要引入的新角色",
|
||||
character_specification=json.dumps(spec, ensure_ascii=False, indent=2),
|
||||
mcp_references="" # MCP工具通过AI服务自动加载
|
||||
)
|
||||
|
||||
logger.info(f"🔧 角色详情生成: enable_mcp={enable_mcp}")
|
||||
|
||||
# 调用AI生成
|
||||
try:
|
||||
character_data = await self.ai_service.call_with_json_retry(
|
||||
prompt=prompt,
|
||||
max_retries=2, # 减少重试次数以加快速度
|
||||
)
|
||||
|
||||
char_name = character_data.get('name', '未知')
|
||||
logger.info(f" ✅ 角色详情生成成功: {char_name}")
|
||||
logger.debug(f" 角色数据字段: {list(character_data.keys())}")
|
||||
|
||||
# 确保关键字段存在
|
||||
if 'name' not in character_data or not character_data['name']:
|
||||
logger.warning(f" ⚠️ AI返回的角色数据缺少name字段,使用规格中的信息")
|
||||
character_data['name'] = spec.get('name', f"新角色{spec.get('role_description', '')[:10]}")
|
||||
|
||||
return character_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f" ❌ 生成角色详情失败: {e}")
|
||||
raise
|
||||
|
||||
async def _create_character_record(
|
||||
self,
|
||||
project_id: str,
|
||||
character_data: Dict[str, Any],
|
||||
db: AsyncSession
|
||||
) -> Character:
|
||||
"""创建角色数据库记录"""
|
||||
|
||||
is_organization = character_data.get("is_organization", False)
|
||||
|
||||
# 提取职业信息(支持通过名称匹配)
|
||||
career_info = character_data.get("career_info", {})
|
||||
raw_main_career_name = career_info.get("main_career_name") if career_info else None
|
||||
main_career_stage = career_info.get("main_career_stage", 1) if career_info else None
|
||||
raw_sub_careers_data = career_info.get("sub_careers", []) if career_info else []
|
||||
|
||||
# 🔧 通过职业名称匹配数据库中的职业ID
|
||||
from app.models.career import Career, CharacterCareer
|
||||
main_career_id = None
|
||||
sub_careers_data = []
|
||||
|
||||
# 匹配主职业名称
|
||||
if raw_main_career_name and not is_organization:
|
||||
career_check = await db.execute(
|
||||
select(Career).where(
|
||||
Career.name == raw_main_career_name,
|
||||
Career.project_id == project_id,
|
||||
Career.type == 'main'
|
||||
)
|
||||
)
|
||||
matched_career = career_check.scalar_one_or_none()
|
||||
if matched_career:
|
||||
main_career_id = matched_career.id
|
||||
# ✅ 验证阶段不超过最高阶段
|
||||
if main_career_stage and main_career_stage > matched_career.max_stage:
|
||||
logger.warning(f" ⚠️ AI返回的主职业阶段({main_career_stage})超过最高阶段({matched_career.max_stage}),自动修正为最高阶段")
|
||||
main_career_stage = matched_career.max_stage
|
||||
logger.info(f" ✅ 主职业名称匹配成功: {raw_main_career_name} -> ID: {main_career_id}, 阶段: {main_career_stage}/{matched_career.max_stage}")
|
||||
else:
|
||||
logger.warning(f" ⚠️ AI返回的主职业名称未找到: {raw_main_career_name}")
|
||||
|
||||
# 匹配副职业名称
|
||||
if raw_sub_careers_data and not is_organization and isinstance(raw_sub_careers_data, list):
|
||||
for sub_data in raw_sub_careers_data[:2]:
|
||||
if isinstance(sub_data, dict):
|
||||
career_name = sub_data.get('career_name')
|
||||
if career_name:
|
||||
career_check = await db.execute(
|
||||
select(Career).where(
|
||||
Career.name == career_name,
|
||||
Career.project_id == project_id,
|
||||
Career.type == 'sub'
|
||||
)
|
||||
)
|
||||
matched_career = career_check.scalar_one_or_none()
|
||||
if matched_career:
|
||||
sub_stage = sub_data.get('stage', 1)
|
||||
# ✅ 验证阶段不超过最高阶段
|
||||
if sub_stage > matched_career.max_stage:
|
||||
logger.warning(f" ⚠️ AI返回的副职业阶段({sub_stage})超过最高阶段({matched_career.max_stage}),自动修正为最高阶段")
|
||||
sub_stage = matched_career.max_stage
|
||||
|
||||
sub_careers_data.append({
|
||||
'career_id': matched_career.id,
|
||||
'stage': sub_stage
|
||||
})
|
||||
logger.info(f" ✅ 副职业名称匹配成功: {career_name} -> ID: {matched_career.id}, 阶段: {sub_stage}/{matched_career.max_stage}")
|
||||
else:
|
||||
logger.warning(f" ⚠️ AI返回的副职业名称未找到: {career_name}")
|
||||
|
||||
# 创建角色
|
||||
character = Character(
|
||||
project_id=project_id,
|
||||
name=character_data.get("name", "未命名角色"),
|
||||
age=str(character_data.get("age", "")),
|
||||
gender=character_data.get("gender"),
|
||||
is_organization=is_organization,
|
||||
role_type=character_data.get("role_type", "supporting"),
|
||||
personality=character_data.get("personality", ""),
|
||||
background=character_data.get("background", ""),
|
||||
appearance=character_data.get("appearance", ""),
|
||||
relationships=character_data.get("relationships_text", ""),
|
||||
organization_type=character_data.get("organization_type") if is_organization else None,
|
||||
organization_purpose=character_data.get("organization_purpose") if is_organization else None,
|
||||
traits=json.dumps(character_data.get("traits", []), ensure_ascii=False) if character_data.get("traits") else None,
|
||||
main_career_id=main_career_id,
|
||||
main_career_stage=main_career_stage if main_career_id else None,
|
||||
sub_careers=json.dumps(sub_careers_data, ensure_ascii=False) if sub_careers_data else None
|
||||
)
|
||||
|
||||
db.add(character)
|
||||
await db.flush()
|
||||
|
||||
# 处理主职业关联
|
||||
if main_career_id and not is_organization:
|
||||
char_career = CharacterCareer(
|
||||
character_id=character.id,
|
||||
career_id=main_career_id,
|
||||
career_type='main',
|
||||
current_stage=main_career_stage,
|
||||
stage_progress=0
|
||||
)
|
||||
db.add(char_career)
|
||||
logger.info(f" ✅ 创建主职业关联: {character.name} -> {raw_main_career_name}")
|
||||
|
||||
# 处理副职业关联
|
||||
if sub_careers_data and not is_organization:
|
||||
for sub_data in sub_careers_data:
|
||||
char_career = CharacterCareer(
|
||||
character_id=character.id,
|
||||
career_id=sub_data['career_id'],
|
||||
career_type='sub',
|
||||
current_stage=sub_data['stage'],
|
||||
stage_progress=0
|
||||
)
|
||||
db.add(char_career)
|
||||
logger.info(f" ✅ 创建副职业关联: {character.name}, 数量: {len(sub_careers_data)}")
|
||||
|
||||
# 如果是组织,创建Organization记录
|
||||
if is_organization:
|
||||
org = Organization(
|
||||
character_id=character.id,
|
||||
project_id=project_id,
|
||||
member_count=0,
|
||||
power_level=character_data.get("power_level", 50),
|
||||
location=character_data.get("location"),
|
||||
motto=character_data.get("motto"),
|
||||
color=character_data.get("color")
|
||||
)
|
||||
db.add(org)
|
||||
await db.flush()
|
||||
logger.info(f" ✅ 创建组织详情: {character.name}")
|
||||
|
||||
return character
|
||||
|
||||
async def _create_relationships(
|
||||
self,
|
||||
new_character: Character,
|
||||
relationship_specs: List[Dict[str, Any]],
|
||||
existing_characters: List[Character],
|
||||
project_id: str,
|
||||
db: AsyncSession
|
||||
) -> List[CharacterRelationship]:
|
||||
"""创建角色关系"""
|
||||
|
||||
if not relationship_specs:
|
||||
return []
|
||||
|
||||
relationships = []
|
||||
|
||||
for rel_spec in relationship_specs:
|
||||
try:
|
||||
target_name = rel_spec.get("target_character_name")
|
||||
if not target_name:
|
||||
continue
|
||||
|
||||
# 查找目标角色
|
||||
target_char = next(
|
||||
(c for c in existing_characters if c.name == target_name),
|
||||
None
|
||||
)
|
||||
|
||||
if not target_char:
|
||||
logger.warning(f" ⚠️ 目标角色不存在: {target_name}")
|
||||
continue
|
||||
|
||||
# 检查关系是否已存在
|
||||
existing_rel = await db.execute(
|
||||
select(CharacterRelationship).where(
|
||||
CharacterRelationship.project_id == project_id,
|
||||
CharacterRelationship.character_from_id == new_character.id,
|
||||
CharacterRelationship.character_to_id == target_char.id
|
||||
)
|
||||
)
|
||||
if existing_rel.scalar_one_or_none():
|
||||
logger.debug(f" ℹ️ 关系已存在: {new_character.name} -> {target_name}")
|
||||
continue
|
||||
|
||||
# 创建关系
|
||||
relationship = CharacterRelationship(
|
||||
project_id=project_id,
|
||||
character_from_id=new_character.id,
|
||||
character_to_id=target_char.id,
|
||||
relationship_name=rel_spec.get("relationship_type", "未知关系"),
|
||||
intimacy_level=rel_spec.get("intimacy_level", 50),
|
||||
description=rel_spec.get("description", ""),
|
||||
status=rel_spec.get("status", "active"),
|
||||
source="auto" # 标记为自动生成
|
||||
)
|
||||
|
||||
# 尝试匹配预定义关系类型
|
||||
rel_type_name = rel_spec.get("relationship_type")
|
||||
if rel_type_name:
|
||||
rel_type_result = await db.execute(
|
||||
select(RelationshipType).where(
|
||||
RelationshipType.name == rel_type_name
|
||||
)
|
||||
)
|
||||
rel_type = rel_type_result.scalar_one_or_none()
|
||||
if rel_type:
|
||||
relationship.relationship_type_id = rel_type.id
|
||||
|
||||
db.add(relationship)
|
||||
relationships.append(relationship)
|
||||
|
||||
logger.info(
|
||||
f" ✅ 创建关系: {new_character.name} -> {target_name} "
|
||||
f"({rel_spec.get('relationship_type', '未知')})"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f" ❌ 创建关系失败: {e}")
|
||||
continue
|
||||
|
||||
return relationships
|
||||
|
||||
|
||||
# 全局实例缓存
|
||||
_auto_character_service_instance: Optional[AutoCharacterService] = None
|
||||
|
||||
|
||||
def get_auto_character_service(ai_service: AIService) -> AutoCharacterService:
|
||||
"""获取自动角色服务实例(单例模式)"""
|
||||
global _auto_character_service_instance
|
||||
if _auto_character_service_instance is None:
|
||||
_auto_character_service_instance = AutoCharacterService(ai_service)
|
||||
return _auto_character_service_instance
|
||||
@@ -0,0 +1,497 @@
|
||||
"""自动组织引入服务 - 在续写大纲时根据剧情推进自动引入新组织"""
|
||||
from typing import List, Dict, Any, Optional, Callable, Awaitable
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
import json
|
||||
|
||||
from app.models.character import Character
|
||||
from app.models.relationship import Organization, OrganizationMember
|
||||
from app.models.project import Project
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.prompt_service import PromptService
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AutoOrganizationService:
|
||||
"""自动组织引入服务"""
|
||||
|
||||
def __init__(self, ai_service: AIService):
|
||||
self.ai_service = ai_service
|
||||
|
||||
async def analyze_and_create_organizations(
|
||||
self,
|
||||
project_id: str,
|
||||
outline_content: str,
|
||||
existing_characters: List[Character],
|
||||
existing_organizations: List[Dict[str, Any]],
|
||||
db: AsyncSession,
|
||||
user_id: str = None,
|
||||
enable_mcp: bool = True,
|
||||
all_chapters_brief: str = "",
|
||||
start_chapter: int = 1,
|
||||
chapter_count: int = 3,
|
||||
plot_stage: str = "发展",
|
||||
story_direction: str = "继续推进主线剧情",
|
||||
preview_only: bool = False,
|
||||
progress_callback: Optional[Callable[[str], Awaitable[None]]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
预测性分析并创建需要的新组织
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
outline_content: 当前批次大纲内容(用于向后兼容,实际不使用)
|
||||
existing_characters: 现有角色列表
|
||||
existing_organizations: 现有组织列表
|
||||
db: 数据库会话
|
||||
user_id: 用户ID(用于MCP和自定义提示词)
|
||||
enable_mcp: 是否启用MCP增强
|
||||
all_chapters_brief: 已有章节概览
|
||||
start_chapter: 起始章节号
|
||||
chapter_count: 续写章节数
|
||||
plot_stage: 剧情阶段
|
||||
story_direction: 故事发展方向
|
||||
preview_only: 仅预测不创建(用于组织确认机制)
|
||||
|
||||
Returns:
|
||||
{
|
||||
"new_organizations": [组织对象列表], # preview_only=True时为空
|
||||
"members_created": [成员关系列表], # preview_only=True时为空
|
||||
"organization_count": 新增组织数量,
|
||||
"analysis_result": AI分析结果,
|
||||
"predicted_organizations": [预测的组织数据] # 仅preview_only=True时返回
|
||||
"needs_new_organizations": bool,
|
||||
"reason": str
|
||||
}
|
||||
"""
|
||||
logger.info(f"🏛️ 【组织引入】预测性分析:检测是否需要引入新组织...")
|
||||
logger.info(f" - 项目ID: {project_id}")
|
||||
logger.info(f" - 续写计划: 第{start_chapter}章起,共{chapter_count}章")
|
||||
logger.info(f" - 剧情阶段: {plot_stage}")
|
||||
logger.info(f" - 发展方向: {story_direction}")
|
||||
logger.info(f" - 现有角色数: {len(existing_characters)}")
|
||||
logger.info(f" - 现有组织数: {len(existing_organizations)}")
|
||||
|
||||
# 1. 获取项目信息
|
||||
project_result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = project_result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise ValueError("项目不存在")
|
||||
|
||||
# 2. 构建现有组织信息摘要
|
||||
existing_orgs_summary = self._build_organization_summary(existing_organizations)
|
||||
existing_chars_summary = self._build_character_summary(existing_characters)
|
||||
|
||||
# 3. AI预测性分析是否需要新组织
|
||||
if progress_callback:
|
||||
await progress_callback("🤖 AI分析组织需求...")
|
||||
|
||||
analysis_result = await self._analyze_organization_needs(
|
||||
project=project,
|
||||
outline_content=outline_content,
|
||||
existing_orgs_summary=existing_orgs_summary,
|
||||
existing_chars_summary=existing_chars_summary,
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
enable_mcp=enable_mcp,
|
||||
all_chapters_brief=all_chapters_brief,
|
||||
start_chapter=start_chapter,
|
||||
chapter_count=chapter_count,
|
||||
plot_stage=plot_stage,
|
||||
story_direction=story_direction
|
||||
)
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("✅ 组织需求分析完成")
|
||||
|
||||
# 4. 判断是否需要创建组织
|
||||
if not analysis_result or not analysis_result.get("needs_new_organizations"):
|
||||
logger.info("✅ AI判断:当前剧情不需要引入新组织")
|
||||
return {
|
||||
"new_organizations": [],
|
||||
"members_created": [],
|
||||
"organization_count": 0,
|
||||
"analysis_result": analysis_result,
|
||||
"predicted_organizations": [],
|
||||
"needs_new_organizations": False,
|
||||
"reason": analysis_result.get("reason", "当前剧情不需要新组织")
|
||||
}
|
||||
|
||||
# 5. 如果是预览模式,仅返回预测结果,不创建组织
|
||||
if preview_only:
|
||||
organization_specs = analysis_result.get("organization_specifications", [])
|
||||
logger.info(f"🔮 预览模式:预测到 {len(organization_specs)} 个组织,不创建数据库记录")
|
||||
return {
|
||||
"new_organizations": [],
|
||||
"members_created": [],
|
||||
"organization_count": 0,
|
||||
"analysis_result": analysis_result,
|
||||
"predicted_organizations": organization_specs,
|
||||
"needs_new_organizations": True,
|
||||
"reason": analysis_result.get("reason", "预测需要新组织")
|
||||
}
|
||||
|
||||
# 6. 批量生成新组织(非预览模式)
|
||||
new_organizations = []
|
||||
members_created = []
|
||||
|
||||
organization_specs = analysis_result.get("organization_specifications", [])
|
||||
logger.info(f"🎯 AI建议引入 {len(organization_specs)} 个新组织")
|
||||
|
||||
for idx, spec in enumerate(organization_specs):
|
||||
try:
|
||||
spec_name = spec.get('name', spec.get('organization_description', '未命名'))
|
||||
logger.info(f" [{idx+1}/{len(organization_specs)}] 生成组织规格: {spec_name}")
|
||||
logger.debug(f" 组织规格内容: {json.dumps(spec, ensure_ascii=False)}")
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback(f"🏛️ [{idx+1}/{len(organization_specs)}] 生成组织详情: {spec_name}")
|
||||
|
||||
# 生成组织详细信息
|
||||
organization_data = await self._generate_organization_details(
|
||||
spec=spec,
|
||||
project=project,
|
||||
existing_characters=existing_characters,
|
||||
existing_organizations=existing_organizations,
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
enable_mcp=enable_mcp
|
||||
)
|
||||
|
||||
logger.debug(f" AI生成的组织数据: {json.dumps(organization_data, ensure_ascii=False)[:200]}")
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback(f"💾 [{idx+1}/{len(organization_specs)}] 保存组织: {organization_data.get('name', spec_name)}")
|
||||
|
||||
# 创建组织记录(先创建Character记录,再创建Organization记录)
|
||||
character, organization = await self._create_organization_record(
|
||||
project_id=project_id,
|
||||
organization_data=organization_data,
|
||||
db=db
|
||||
)
|
||||
|
||||
new_organizations.append({
|
||||
"character": character,
|
||||
"organization": organization
|
||||
})
|
||||
logger.info(f" ✅ 创建新组织: {character.name}, ID: {organization.id}")
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback(f"✅ [{idx+1}/{len(organization_specs)}] 组织创建成功: {character.name}")
|
||||
|
||||
# 建立成员关系
|
||||
members_data = organization_data.get("initial_members", [])
|
||||
if members_data:
|
||||
logger.info(f" 🔗 开始创建 {len(members_data)} 个成员关系...")
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback(f"🔗 [{idx+1}/{len(organization_specs)}] 建立 {len(members_data)} 个成员关系")
|
||||
|
||||
members = await self._create_member_relationships(
|
||||
organization=organization,
|
||||
member_specs=members_data,
|
||||
existing_characters=existing_characters,
|
||||
project_id=project_id,
|
||||
db=db
|
||||
)
|
||||
members_created.extend(members)
|
||||
logger.info(f" ✅ 实际创建了 {len(members)} 个成员关系记录")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f" ❌ 创建组织失败: {e}", exc_info=True)
|
||||
continue
|
||||
|
||||
# 7. 提交事务(注意:这里只flush,让调用方commit)
|
||||
await db.flush()
|
||||
|
||||
logger.info(f"🎉 自动组织引入完成: 新增{len(new_organizations)}个组织, {len(members_created)}个成员关系")
|
||||
|
||||
return {
|
||||
"new_organizations": new_organizations,
|
||||
"members_created": members_created,
|
||||
"organization_count": len(new_organizations),
|
||||
"analysis_result": analysis_result,
|
||||
"predicted_organizations": [],
|
||||
"needs_new_organizations": True,
|
||||
"reason": analysis_result.get("reason", "")
|
||||
}
|
||||
|
||||
def _build_organization_summary(self, organizations: List[Dict[str, Any]]) -> str:
|
||||
"""构建现有组织摘要"""
|
||||
if not organizations:
|
||||
return "暂无组织"
|
||||
|
||||
summary = []
|
||||
for org in organizations:
|
||||
org_name = org.get("name", "未知")
|
||||
org_type = org.get("organization_type", "未知类型")
|
||||
power_level = org.get("power_level", 50)
|
||||
purpose = (org.get("organization_purpose") or "")[:50]
|
||||
summary.append(f"- {org_name} ({org_type}, 势力等级:{power_level}): {purpose}")
|
||||
|
||||
return "\n".join(summary[:15]) # 最多显示15个
|
||||
|
||||
def _build_character_summary(self, characters: List[Character]) -> str:
|
||||
"""构建现有角色摘要"""
|
||||
if not characters:
|
||||
return "暂无角色"
|
||||
|
||||
summary = []
|
||||
for char in characters:
|
||||
if not char.is_organization: # 只统计非组织角色
|
||||
char_role = char.role_type or "未知"
|
||||
personality = (char.personality or "")[:30]
|
||||
summary.append(f"- {char.name} ({char_role}): {personality}")
|
||||
|
||||
return "\n".join(summary[:20]) # 最多显示20个
|
||||
|
||||
async def _analyze_organization_needs(
|
||||
self,
|
||||
project: Project,
|
||||
outline_content: str,
|
||||
existing_orgs_summary: str,
|
||||
existing_chars_summary: str,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
enable_mcp: bool,
|
||||
all_chapters_brief: str = "",
|
||||
start_chapter: int = 1,
|
||||
chapter_count: int = 3,
|
||||
plot_stage: str = "发展",
|
||||
story_direction: str = "继续推进主线剧情"
|
||||
) -> Dict[str, Any]:
|
||||
"""AI预测性分析是否需要新组织"""
|
||||
|
||||
# 构建分析提示词
|
||||
template = await PromptService.get_template(
|
||||
"AUTO_ORGANIZATION_ANALYSIS",
|
||||
user_id,
|
||||
db
|
||||
)
|
||||
|
||||
# 使用新的预测性分析参数
|
||||
prompt = PromptService.format_prompt(
|
||||
template,
|
||||
title=project.title,
|
||||
theme=project.theme or "未设定",
|
||||
genre=project.genre or "未设定",
|
||||
time_period=project.world_time_period or "未设定",
|
||||
location=project.world_location or "未设定",
|
||||
atmosphere=project.world_atmosphere or "未设定",
|
||||
existing_organizations=existing_orgs_summary,
|
||||
existing_characters=existing_chars_summary,
|
||||
all_chapters_brief=all_chapters_brief,
|
||||
start_chapter=start_chapter,
|
||||
chapter_count=chapter_count,
|
||||
plot_stage=plot_stage,
|
||||
story_direction=story_direction
|
||||
)
|
||||
|
||||
try:
|
||||
# 使用统一的JSON调用方法(支持自动MCP工具加载)
|
||||
analysis = await self.ai_service.call_with_json_retry(
|
||||
prompt=prompt,
|
||||
max_retries=3,
|
||||
)
|
||||
|
||||
logger.info(f" ✅ AI分析完成: needs_new_organizations={analysis.get('needs_new_organizations')}")
|
||||
return analysis
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f" ❌ 组织需求分析JSON解析失败: {e}")
|
||||
return {"needs_new_organizations": False}
|
||||
except Exception as e:
|
||||
logger.error(f" ❌ 组织需求分析失败: {e}")
|
||||
return {"needs_new_organizations": False}
|
||||
|
||||
async def _generate_organization_details(
|
||||
self,
|
||||
spec: Dict[str, Any],
|
||||
project: Project,
|
||||
existing_characters: List[Character],
|
||||
existing_organizations: List[Dict[str, Any]],
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
enable_mcp: bool
|
||||
) -> Dict[str, Any]:
|
||||
"""生成组织详细信息"""
|
||||
|
||||
# 构建组织生成提示词
|
||||
template = await PromptService.get_template(
|
||||
"AUTO_ORGANIZATION_GENERATION",
|
||||
user_id,
|
||||
db
|
||||
)
|
||||
|
||||
existing_orgs_summary = self._build_organization_summary(existing_organizations)
|
||||
existing_chars_summary = self._build_character_summary(existing_characters)
|
||||
|
||||
prompt = PromptService.format_prompt(
|
||||
template,
|
||||
title=project.title,
|
||||
genre=project.genre or "未设定",
|
||||
theme=project.theme or "未设定",
|
||||
time_period=project.world_time_period or "未设定",
|
||||
location=project.world_location or "未设定",
|
||||
atmosphere=project.world_atmosphere or "未设定",
|
||||
rules=project.world_rules or "未设定",
|
||||
existing_organizations=existing_orgs_summary,
|
||||
existing_characters=existing_chars_summary,
|
||||
plot_context="根据剧情需要引入的新组织",
|
||||
organization_specification=json.dumps(spec, ensure_ascii=False, indent=2),
|
||||
mcp_references="" # 暂时不使用MCP增强
|
||||
)
|
||||
|
||||
# 调用AI生成(使用统一的JSON调用方法)
|
||||
try:
|
||||
# 使用统一的JSON调用方法(支持自动MCP工具加载)
|
||||
organization_data = await self.ai_service.call_with_json_retry(
|
||||
prompt=prompt,
|
||||
max_retries=3,
|
||||
)
|
||||
|
||||
org_name = organization_data.get('name', '未知')
|
||||
logger.info(f" ✅ 组织详情生成成功: {org_name}")
|
||||
logger.debug(f" 组织数据字段: {list(organization_data.keys())}")
|
||||
|
||||
# 确保关键字段存在
|
||||
if 'name' not in organization_data or not organization_data['name']:
|
||||
logger.warning(f" ⚠️ AI返回的组织数据缺少name字段,使用规格中的信息")
|
||||
organization_data['name'] = spec.get('name', f"新组织{spec.get('organization_description', '')[:10]}")
|
||||
|
||||
return organization_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f" ❌ 生成组织详情失败: {e}")
|
||||
raise
|
||||
|
||||
async def _create_organization_record(
|
||||
self,
|
||||
project_id: str,
|
||||
organization_data: Dict[str, Any],
|
||||
db: AsyncSession
|
||||
) -> tuple:
|
||||
"""创建组织数据库记录(包括Character和Organization)"""
|
||||
|
||||
# 首先创建Character记录(is_organization=True)
|
||||
character = Character(
|
||||
project_id=project_id,
|
||||
name=organization_data.get("name", "未命名组织"),
|
||||
is_organization=True,
|
||||
role_type=organization_data.get("role_type", "supporting"),
|
||||
personality=organization_data.get("personality", ""), # 组织特性
|
||||
background=organization_data.get("background", ""), # 组织背景
|
||||
appearance=organization_data.get("appearance", ""), # 外在表现
|
||||
organization_type=organization_data.get("organization_type"),
|
||||
organization_purpose=organization_data.get("organization_purpose"),
|
||||
traits=json.dumps(organization_data.get("traits", []), ensure_ascii=False) if organization_data.get("traits") else None
|
||||
)
|
||||
|
||||
db.add(character)
|
||||
await db.flush()
|
||||
|
||||
# 然后创建Organization记录
|
||||
organization = Organization(
|
||||
character_id=character.id,
|
||||
project_id=project_id,
|
||||
power_level=organization_data.get("power_level", 50),
|
||||
member_count=0,
|
||||
location=organization_data.get("location"),
|
||||
motto=organization_data.get("motto"),
|
||||
color=organization_data.get("color")
|
||||
)
|
||||
|
||||
db.add(organization)
|
||||
await db.flush()
|
||||
|
||||
logger.info(f" ✅ 创建组织记录: {character.name}, Organization ID: {organization.id}")
|
||||
|
||||
return character, organization
|
||||
|
||||
async def _create_member_relationships(
|
||||
self,
|
||||
organization: Organization,
|
||||
member_specs: List[Dict[str, Any]],
|
||||
existing_characters: List[Character],
|
||||
project_id: str,
|
||||
db: AsyncSession
|
||||
) -> List[OrganizationMember]:
|
||||
"""创建组织成员关系"""
|
||||
|
||||
if not member_specs:
|
||||
return []
|
||||
|
||||
members = []
|
||||
|
||||
for member_spec in member_specs:
|
||||
try:
|
||||
character_name = member_spec.get("character_name")
|
||||
if not character_name:
|
||||
continue
|
||||
|
||||
# 查找目标角色
|
||||
target_char = next(
|
||||
(c for c in existing_characters if c.name == character_name and not c.is_organization),
|
||||
None
|
||||
)
|
||||
|
||||
if not target_char:
|
||||
logger.warning(f" ⚠️ 目标角色不存在: {character_name}")
|
||||
continue
|
||||
|
||||
# 检查成员关系是否已存在
|
||||
existing_member = await db.execute(
|
||||
select(OrganizationMember).where(
|
||||
OrganizationMember.organization_id == organization.id,
|
||||
OrganizationMember.character_id == target_char.id
|
||||
)
|
||||
)
|
||||
if existing_member.scalar_one_or_none():
|
||||
logger.debug(f" ℹ️ 成员关系已存在: {character_name} -> {organization.id}")
|
||||
continue
|
||||
|
||||
# 创建成员关系
|
||||
member = OrganizationMember(
|
||||
organization_id=organization.id,
|
||||
character_id=target_char.id,
|
||||
position=member_spec.get("position", "成员"),
|
||||
rank=member_spec.get("rank", 0),
|
||||
loyalty=member_spec.get("loyalty", 50),
|
||||
status=member_spec.get("status", "active"),
|
||||
joined_at=member_spec.get("joined_at"),
|
||||
source="auto" # 标记为自动生成
|
||||
)
|
||||
|
||||
db.add(member)
|
||||
members.append(member)
|
||||
|
||||
logger.info(
|
||||
f" ✅ 创建成员关系: {character_name} -> {organization.id} "
|
||||
f"({member_spec.get('position', '成员')})"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f" ❌ 创建成员关系失败: {e}")
|
||||
continue
|
||||
|
||||
# 更新组织成员数量
|
||||
if members:
|
||||
organization.member_count = (organization.member_count or 0) + len(members)
|
||||
|
||||
return members
|
||||
|
||||
|
||||
# 全局实例缓存
|
||||
_auto_organization_service_instance: Optional[AutoOrganizationService] = None
|
||||
|
||||
|
||||
def get_auto_organization_service(ai_service: AIService) -> AutoOrganizationService:
|
||||
"""获取自动组织服务实例(单例模式)"""
|
||||
global _auto_organization_service_instance
|
||||
if _auto_organization_service_instance is None:
|
||||
_auto_organization_service_instance = AutoOrganizationService(ai_service)
|
||||
return _auto_organization_service_instance
|
||||
@@ -0,0 +1,234 @@
|
||||
"""职业生成服务"""
|
||||
from typing import Dict, Any, List
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
import json
|
||||
|
||||
from app.models.project import Project
|
||||
from app.models.career import Career
|
||||
from app.services.ai_service import AIService
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class CareerService:
|
||||
"""职业相关业务逻辑服务"""
|
||||
|
||||
@staticmethod
|
||||
async def get_career_generation_prompt(
|
||||
project: Project,
|
||||
main_career_count: int = 2,
|
||||
sub_career_count: int = 6
|
||||
) -> str:
|
||||
"""
|
||||
构建职业体系生成的提示词
|
||||
|
||||
Args:
|
||||
project: 项目对象
|
||||
main_career_count: 主职业数量
|
||||
sub_career_count: 副职业数量
|
||||
|
||||
Returns:
|
||||
完整的提示词
|
||||
"""
|
||||
project_context = f"""
|
||||
项目信息:
|
||||
- 书名:{project.title}
|
||||
- 类型:{project.genre or '未设定'}
|
||||
- 主题:{project.theme or '未设定'}
|
||||
- 时间背景:{project.world_time_period or '未设定'}
|
||||
- 地理位置:{project.world_location or '未设定'}
|
||||
- 氛围基调:{project.world_atmosphere or '未设定'}
|
||||
- 世界规则:{project.world_rules or '未设定'}
|
||||
"""
|
||||
|
||||
user_requirements = f"""
|
||||
生成要求:
|
||||
- 主职业数量:{main_career_count}个
|
||||
- 副职业数量:{sub_career_count}个
|
||||
- 主职业必须严格符合世界观规则,体现核心能力体系
|
||||
- 副职业可以更加自由灵活,包含生产、辅助、特殊类型
|
||||
"""
|
||||
|
||||
prompt = f"""{project_context}
|
||||
|
||||
{user_requirements}
|
||||
|
||||
请为这个小说项目生成完整的职业体系。返回JSON格式,结构如下:
|
||||
|
||||
{{
|
||||
"main_careers": [
|
||||
{{
|
||||
"name": "职业名称",
|
||||
"description": "职业描述(100-200字)",
|
||||
"category": "职业分类(如:战斗系、法术系、体修系等)",
|
||||
"stages": [
|
||||
{{"level": 1, "name": "阶段名称", "description": "阶段描述"}},
|
||||
{{"level": 2, "name": "阶段名称", "description": "阶段描述"}},
|
||||
...(共10个阶段)
|
||||
],
|
||||
"max_stage": 10,
|
||||
"requirements": "职业要求(如:需要特定天赋、资质等)",
|
||||
"special_abilities": "特殊能力描述",
|
||||
"worldview_rules": "世界观规则关联(说明该职业如何融入世界观)",
|
||||
"attribute_bonuses": {{"strength": "+10%", "intelligence": "+5%"}}
|
||||
}}
|
||||
],
|
||||
"sub_careers": [
|
||||
{{
|
||||
"name": "副职业名称",
|
||||
"description": "职业描述",
|
||||
"category": "生产系/辅助系/特殊系",
|
||||
"stages": [
|
||||
{{"level": 1, "name": "阶段名称", "description": "阶段描述"}},
|
||||
...(5-8个阶段)
|
||||
],
|
||||
"max_stage": 5,
|
||||
"requirements": "职业要求",
|
||||
"special_abilities": "特殊能力"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
重要注意事项:
|
||||
1. 主职业的阶段设定要详细,体现明确的成长路径,阶段名称要有特色
|
||||
2. 根据小说类型选择合适的职业:
|
||||
- 修仙类:剑修、体修、法修、符修等,阶段如:炼气、筑基、金丹、元婴...
|
||||
- 玄幻类:战士、法师、刺客等,阶段如:见习、初级、中级、高级...
|
||||
- 都市异能:异能者分类,阶段如:觉醒、初阶、中阶、高阶...
|
||||
- 科幻未来:基因战士、机甲师等,阶段如:E级、D级、C级、B级...
|
||||
3. 副职业要有实用性和趣味性,如:炼丹师、炼器师、阵法师、驯兽师、医师等
|
||||
4. 所有职业都要符合项目的整体世界观设定
|
||||
5. 阶段描述要简洁明了,体现该阶段的核心特征
|
||||
6. **只返回纯JSON对象,不要添加任何解释文字或markdown标记**
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
@staticmethod
|
||||
async def parse_and_save_careers(
|
||||
career_data: Dict[str, Any],
|
||||
project_id: str,
|
||||
db: AsyncSession
|
||||
) -> Dict[str, List[str]]:
|
||||
"""
|
||||
解析AI返回的职业数据并保存到数据库
|
||||
|
||||
Args:
|
||||
career_data: AI返回的职业数据(已解析为dict)
|
||||
project_id: 项目ID
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
{"main_careers": [...], "sub_careers": [...]} 创建的职业名称列表
|
||||
"""
|
||||
result = {
|
||||
"main_careers": [],
|
||||
"sub_careers": []
|
||||
}
|
||||
|
||||
# 保存主职业
|
||||
for idx, career_info in enumerate(career_data.get("main_careers", [])):
|
||||
try:
|
||||
stages_json = json.dumps(career_info.get("stages", []), ensure_ascii=False)
|
||||
attribute_bonuses = career_info.get("attribute_bonuses")
|
||||
attribute_bonuses_json = json.dumps(attribute_bonuses, ensure_ascii=False) if attribute_bonuses else None
|
||||
|
||||
career = Career(
|
||||
project_id=project_id,
|
||||
name=career_info.get("name", f"未命名主职业{idx+1}"),
|
||||
type="main",
|
||||
description=career_info.get("description"),
|
||||
category=career_info.get("category"),
|
||||
stages=stages_json,
|
||||
max_stage=career_info.get("max_stage", 10),
|
||||
requirements=career_info.get("requirements"),
|
||||
special_abilities=career_info.get("special_abilities"),
|
||||
worldview_rules=career_info.get("worldview_rules"),
|
||||
attribute_bonuses=attribute_bonuses_json,
|
||||
source="ai"
|
||||
)
|
||||
db.add(career)
|
||||
await db.flush()
|
||||
result["main_careers"].append(career.name)
|
||||
logger.info(f" ✅ 创建主职业:{career.name}")
|
||||
except Exception as e:
|
||||
logger.error(f" ❌ 创建主职业失败:{str(e)}")
|
||||
continue
|
||||
|
||||
# 保存副职业
|
||||
for idx, career_info in enumerate(career_data.get("sub_careers", [])):
|
||||
try:
|
||||
stages_json = json.dumps(career_info.get("stages", []), ensure_ascii=False)
|
||||
attribute_bonuses = career_info.get("attribute_bonuses")
|
||||
attribute_bonuses_json = json.dumps(attribute_bonuses, ensure_ascii=False) if attribute_bonuses else None
|
||||
|
||||
career = Career(
|
||||
project_id=project_id,
|
||||
name=career_info.get("name", f"未命名副职业{idx+1}"),
|
||||
type="sub",
|
||||
description=career_info.get("description"),
|
||||
category=career_info.get("category"),
|
||||
stages=stages_json,
|
||||
max_stage=career_info.get("max_stage", 5),
|
||||
requirements=career_info.get("requirements"),
|
||||
special_abilities=career_info.get("special_abilities"),
|
||||
worldview_rules=career_info.get("worldview_rules"),
|
||||
attribute_bonuses=attribute_bonuses_json,
|
||||
source="ai"
|
||||
)
|
||||
db.add(career)
|
||||
await db.flush()
|
||||
result["sub_careers"].append(career.name)
|
||||
logger.info(f" ✅ 创建副职业:{career.name}")
|
||||
except Exception as e:
|
||||
logger.error(f" ❌ 创建副职业失败:{str(e)}")
|
||||
continue
|
||||
|
||||
await db.commit()
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
async def get_project_careers_summary(project_id: str, db: AsyncSession) -> Dict[str, Any]:
|
||||
"""
|
||||
获取项目职业体系摘要
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
职业体系摘要信息
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(Career).where(Career.project_id == project_id)
|
||||
)
|
||||
careers = result.scalars().all()
|
||||
|
||||
main_careers = []
|
||||
sub_careers = []
|
||||
|
||||
for career in careers:
|
||||
career_info = {
|
||||
"id": career.id,
|
||||
"name": career.name,
|
||||
"category": career.category,
|
||||
"max_stage": career.max_stage
|
||||
}
|
||||
|
||||
if career.type == "main":
|
||||
main_careers.append(career_info)
|
||||
else:
|
||||
sub_careers.append(career_info)
|
||||
|
||||
return {
|
||||
"main_careers": main_careers,
|
||||
"sub_careers": sub_careers,
|
||||
"total_count": len(careers)
|
||||
}
|
||||
|
||||
|
||||
# 创建全局服务实例
|
||||
career_service = CareerService()
|
||||
@@ -0,0 +1,398 @@
|
||||
"""职业更新服务 - 根据章节分析自动更新角色职业信息"""
|
||||
from typing import Dict, Any, List, Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.models.character import Character
|
||||
from app.models.career import Career, CharacterCareer
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class CareerUpdateService:
|
||||
"""职业更新服务 - 根据章节分析结果自动更新角色职业"""
|
||||
|
||||
@staticmethod
|
||||
async def update_careers_from_analysis(
|
||||
db: AsyncSession,
|
||||
project_id: str,
|
||||
character_states: List[Dict[str, Any]],
|
||||
chapter_id: str,
|
||||
chapter_number: int
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
根据章节分析结果更新角色职业
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
project_id: 项目ID
|
||||
character_states: 角色状态变化列表(来自PlotAnalysis)
|
||||
chapter_id: 章节ID
|
||||
chapter_number: 章节编号
|
||||
|
||||
Returns:
|
||||
更新结果字典,包含更新数量和变更日志
|
||||
"""
|
||||
if not character_states:
|
||||
logger.info("📋 角色状态列表为空,跳过职业更新")
|
||||
return {"updated_count": 0, "changes": []}
|
||||
|
||||
updated_count = 0
|
||||
changes_log = []
|
||||
|
||||
logger.info(f"🔍 开始分析第{chapter_number}章的角色职业变化...")
|
||||
|
||||
for char_state in character_states:
|
||||
char_name = char_state.get('character_name')
|
||||
career_changes = char_state.get('career_changes', {})
|
||||
|
||||
# 如果没有职业变化信息,跳过
|
||||
if not career_changes or not isinstance(career_changes, dict):
|
||||
continue
|
||||
|
||||
# 检查是否有实质性的职业变化
|
||||
main_stage_change = career_changes.get('main_career_stage_change', 0)
|
||||
sub_career_changes = career_changes.get('sub_career_changes', [])
|
||||
new_careers = career_changes.get('new_careers', [])
|
||||
|
||||
if main_stage_change == 0 and not sub_career_changes and not new_careers:
|
||||
continue
|
||||
|
||||
logger.info(f" 👤 检测到角色 [{char_name}] 有职业变化")
|
||||
|
||||
# 1. 查询角色
|
||||
char_result = await db.execute(
|
||||
select(Character).where(
|
||||
Character.name == char_name,
|
||||
Character.project_id == project_id
|
||||
)
|
||||
)
|
||||
character = char_result.scalar_one_or_none()
|
||||
|
||||
if not character:
|
||||
logger.warning(f" ⚠️ 角色不存在: {char_name},跳过")
|
||||
continue
|
||||
|
||||
# 2. 更新主职业阶段
|
||||
if main_stage_change != 0 and character.main_career_id:
|
||||
success = await CareerUpdateService._update_main_career_stage(
|
||||
db=db,
|
||||
character=character,
|
||||
stage_change=main_stage_change,
|
||||
chapter_number=chapter_number,
|
||||
career_changes=career_changes,
|
||||
changes_log=changes_log
|
||||
)
|
||||
if success:
|
||||
updated_count += 1
|
||||
|
||||
# 3. 更新副职业(如果有)
|
||||
if sub_career_changes and isinstance(sub_career_changes, list):
|
||||
for sub_change in sub_career_changes:
|
||||
success = await CareerUpdateService._update_sub_career_stage(
|
||||
db=db,
|
||||
character=character,
|
||||
project_id=project_id,
|
||||
sub_change=sub_change,
|
||||
chapter_number=chapter_number,
|
||||
changes_log=changes_log
|
||||
)
|
||||
if success:
|
||||
updated_count += 1
|
||||
|
||||
# 4. 添加新职业(如果有)
|
||||
if new_careers and isinstance(new_careers, list):
|
||||
for new_career_name in new_careers:
|
||||
success = await CareerUpdateService._add_new_career(
|
||||
db=db,
|
||||
character=character,
|
||||
project_id=project_id,
|
||||
career_name=new_career_name,
|
||||
chapter_number=chapter_number,
|
||||
changes_log=changes_log
|
||||
)
|
||||
if success:
|
||||
updated_count += 1
|
||||
|
||||
# 提交所有更改
|
||||
if updated_count > 0:
|
||||
await db.commit()
|
||||
logger.info(f"✅ 职业更新完成: 共更新了 {updated_count} 个角色的职业信息")
|
||||
else:
|
||||
logger.info("📋 本章没有角色职业变化")
|
||||
|
||||
return {
|
||||
"updated_count": updated_count,
|
||||
"changes": changes_log
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def _update_main_career_stage(
|
||||
db: AsyncSession,
|
||||
character: Character,
|
||||
stage_change: int,
|
||||
chapter_number: int,
|
||||
career_changes: Dict[str, Any],
|
||||
changes_log: List[Dict[str, Any]]
|
||||
) -> bool:
|
||||
"""更新主职业阶段"""
|
||||
try:
|
||||
# 查询主职业关联
|
||||
char_career_result = await db.execute(
|
||||
select(CharacterCareer).where(
|
||||
CharacterCareer.character_id == character.id,
|
||||
CharacterCareer.career_type == 'main'
|
||||
)
|
||||
)
|
||||
char_career = char_career_result.scalar_one_or_none()
|
||||
|
||||
if not char_career:
|
||||
logger.warning(f" ⚠️ {character.name} 没有主职业关联记录")
|
||||
return False
|
||||
|
||||
# 查询职业信息
|
||||
career_result = await db.execute(
|
||||
select(Career).where(Career.id == char_career.career_id)
|
||||
)
|
||||
career = career_result.scalar_one_or_none()
|
||||
|
||||
if not career:
|
||||
logger.warning(f" ⚠️ 职业ID {char_career.career_id} 不存在")
|
||||
return False
|
||||
|
||||
# 计算新阶段(不超过最大阶段,不低于1)
|
||||
old_stage = char_career.current_stage
|
||||
new_stage = min(max(1, old_stage + stage_change), career.max_stage)
|
||||
|
||||
# 如果没有实际变化,跳过
|
||||
if new_stage == old_stage:
|
||||
logger.info(f" 📊 {character.name} 的 {career.name} 已达到边界,无法变更")
|
||||
return False
|
||||
|
||||
# 更新CharacterCareer表
|
||||
char_career.current_stage = new_stage
|
||||
|
||||
# 同步更新Character表的冗余字段
|
||||
character.main_career_stage = new_stage
|
||||
|
||||
# 记录变更日志
|
||||
change_desc = f"{'晋升' if stage_change > 0 else '降级'}"
|
||||
breakthrough_desc = career_changes.get('career_breakthrough', '')
|
||||
|
||||
changes_log.append({
|
||||
'character': character.name,
|
||||
'career': career.name,
|
||||
'career_type': 'main',
|
||||
'old_stage': old_stage,
|
||||
'new_stage': new_stage,
|
||||
'change': stage_change,
|
||||
'chapter': chapter_number,
|
||||
'description': breakthrough_desc
|
||||
})
|
||||
|
||||
logger.info(
|
||||
f" ✨ {character.name} 的主职业 [{career.name}] "
|
||||
f"{old_stage}阶 → {new_stage}阶 ({change_desc})"
|
||||
)
|
||||
if breakthrough_desc:
|
||||
logger.info(f" 突破描述: {breakthrough_desc[:50]}...")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f" ❌ 更新主职业失败: {str(e)}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def _update_sub_career_stage(
|
||||
db: AsyncSession,
|
||||
character: Character,
|
||||
project_id: str,
|
||||
sub_change: Dict[str, Any],
|
||||
chapter_number: int,
|
||||
changes_log: List[Dict[str, Any]]
|
||||
) -> bool:
|
||||
"""更新副职业阶段"""
|
||||
try:
|
||||
career_name = sub_change.get('career_name')
|
||||
stage_change = sub_change.get('stage_change', 0)
|
||||
|
||||
if not career_name or stage_change == 0:
|
||||
return False
|
||||
|
||||
# 1. 查询职业(通过名称)
|
||||
career_result = await db.execute(
|
||||
select(Career).where(
|
||||
Career.name == career_name,
|
||||
Career.project_id == project_id,
|
||||
Career.type == 'sub'
|
||||
)
|
||||
)
|
||||
career = career_result.scalar_one_or_none()
|
||||
|
||||
if not career:
|
||||
logger.warning(f" ⚠️ 副职业 [{career_name}] 不存在")
|
||||
return False
|
||||
|
||||
# 2. 查询角色-职业关联
|
||||
char_career_result = await db.execute(
|
||||
select(CharacterCareer).where(
|
||||
CharacterCareer.character_id == character.id,
|
||||
CharacterCareer.career_id == career.id,
|
||||
CharacterCareer.career_type == 'sub'
|
||||
)
|
||||
)
|
||||
char_career = char_career_result.scalar_one_or_none()
|
||||
|
||||
if not char_career:
|
||||
logger.warning(f" ⚠️ {character.name} 没有 [{career_name}] 副职业")
|
||||
return False
|
||||
|
||||
# 3. 计算新阶段
|
||||
old_stage = char_career.current_stage
|
||||
new_stage = min(max(1, old_stage + stage_change), career.max_stage)
|
||||
|
||||
if new_stage == old_stage:
|
||||
return False
|
||||
|
||||
# 4. 更新阶段
|
||||
char_career.current_stage = new_stage
|
||||
|
||||
# 5. 同步更新Character表的sub_careers JSON字段
|
||||
import json
|
||||
sub_careers = json.loads(character.sub_careers) if character.sub_careers else []
|
||||
for sc in sub_careers:
|
||||
if sc.get('career_id') == career.id:
|
||||
sc['stage'] = new_stage
|
||||
break
|
||||
character.sub_careers = json.dumps(sub_careers, ensure_ascii=False)
|
||||
|
||||
# 6. 记录变更
|
||||
changes_log.append({
|
||||
'character': character.name,
|
||||
'career': career.name,
|
||||
'career_type': 'sub',
|
||||
'old_stage': old_stage,
|
||||
'new_stage': new_stage,
|
||||
'change': stage_change,
|
||||
'chapter': chapter_number
|
||||
})
|
||||
|
||||
logger.info(
|
||||
f" ✨ {character.name} 的副职业 [{career.name}] "
|
||||
f"{old_stage}阶 → {new_stage}阶"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f" ❌ 更新副职业失败: {str(e)}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def _add_new_career(
|
||||
db: AsyncSession,
|
||||
character: Character,
|
||||
project_id: str,
|
||||
career_name: str,
|
||||
chapter_number: int,
|
||||
changes_log: List[Dict[str, Any]]
|
||||
) -> bool:
|
||||
"""为角色添加新职业"""
|
||||
try:
|
||||
# 1. 查询职业
|
||||
career_result = await db.execute(
|
||||
select(Career).where(
|
||||
Career.name == career_name,
|
||||
Career.project_id == project_id
|
||||
)
|
||||
)
|
||||
career = career_result.scalar_one_or_none()
|
||||
|
||||
if not career:
|
||||
logger.warning(f" ⚠️ 职业 [{career_name}] 不存在,无法添加")
|
||||
return False
|
||||
|
||||
# 2. 检查是否已存在
|
||||
existing_result = await db.execute(
|
||||
select(CharacterCareer).where(
|
||||
CharacterCareer.character_id == character.id,
|
||||
CharacterCareer.career_id == career.id
|
||||
)
|
||||
)
|
||||
if existing_result.scalar_one_or_none():
|
||||
logger.info(f" 📋 {character.name} 已拥有 [{career_name}],跳过")
|
||||
return False
|
||||
|
||||
# 3. 根据职业类型添加
|
||||
if career.type == 'main':
|
||||
# 检查是否已有主职业
|
||||
if character.main_career_id:
|
||||
logger.warning(f" ⚠️ {character.name} 已有主职业,无法添加新主职业")
|
||||
return False
|
||||
|
||||
# 添加主职业
|
||||
import uuid
|
||||
new_char_career = CharacterCareer(
|
||||
id=str(uuid.uuid4()),
|
||||
character_id=character.id,
|
||||
career_id=career.id,
|
||||
career_type='main',
|
||||
current_stage=1
|
||||
)
|
||||
db.add(new_char_career)
|
||||
|
||||
# 更新Character表
|
||||
character.main_career_id = career.id
|
||||
character.main_career_stage = 1
|
||||
|
||||
logger.info(f" ✨ {character.name} 获得新主职业 [{career_name}]")
|
||||
|
||||
else: # sub职业
|
||||
# 检查副职业数量(最多2个)
|
||||
sub_count_result = await db.execute(
|
||||
select(CharacterCareer).where(
|
||||
CharacterCareer.character_id == character.id,
|
||||
CharacterCareer.career_type == 'sub'
|
||||
)
|
||||
)
|
||||
if len(sub_count_result.scalars().all()) >= 2:
|
||||
logger.warning(f" ⚠️ {character.name} 的副职业已达上限(2个)")
|
||||
return False
|
||||
|
||||
# 添加副职业
|
||||
import uuid
|
||||
new_char_career = CharacterCareer(
|
||||
id=str(uuid.uuid4()),
|
||||
character_id=character.id,
|
||||
career_id=career.id,
|
||||
career_type='sub',
|
||||
current_stage=1
|
||||
)
|
||||
db.add(new_char_career)
|
||||
|
||||
# 更新Character表的sub_careers JSON
|
||||
import json
|
||||
sub_careers = json.loads(character.sub_careers) if character.sub_careers else []
|
||||
sub_careers.append({
|
||||
'career_id': career.id,
|
||||
'stage': 1
|
||||
})
|
||||
character.sub_careers = json.dumps(sub_careers, ensure_ascii=False)
|
||||
|
||||
logger.info(f" ✨ {character.name} 获得新副职业 [{career_name}]")
|
||||
|
||||
# 记录变更
|
||||
changes_log.append({
|
||||
'character': character.name,
|
||||
'career': career.name,
|
||||
'career_type': career.type,
|
||||
'action': 'new',
|
||||
'chapter': chapter_number
|
||||
})
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f" ❌ 添加新职业失败: {str(e)}")
|
||||
return False
|
||||
@@ -0,0 +1,745 @@
|
||||
"""章节上下文构建服务 - 实现RTCO框架的智能上下文构建"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Any, Optional, List
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
import json
|
||||
|
||||
from app.models.chapter import Chapter
|
||||
from app.models.project import Project
|
||||
from app.models.outline import Outline
|
||||
from app.models.character import Character
|
||||
from app.models.career import Career, CharacterCareer
|
||||
from app.models.memory import StoryMemory
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChapterContext:
|
||||
"""
|
||||
章节上下文数据结构
|
||||
|
||||
采用RTCO框架的分层设计:
|
||||
- P0-核心(必须):大纲、衔接点、字数要求
|
||||
- P1-重要(按需):角色、情感基调、风格
|
||||
- P2-参考(条件触发):记忆、故事骨架、MCP资料
|
||||
"""
|
||||
|
||||
# === P0-核心信息(必须包含)===
|
||||
chapter_outline: str = "" # 本章大纲
|
||||
continuation_point: Optional[str] = None # 衔接锚点(上一章结尾)
|
||||
target_word_count: int = 3000 # 目标字数
|
||||
min_word_count: int = 2500 # 最小字数
|
||||
max_word_count: int = 4000 # 最大字数
|
||||
narrative_perspective: str = "第三人称" # 叙事视角
|
||||
|
||||
# === 本章基本信息 ===
|
||||
chapter_number: int = 1 # 章节序号
|
||||
chapter_title: str = "" # 章节标题
|
||||
|
||||
# === 项目基本信息 ===
|
||||
title: str = "" # 书名
|
||||
genre: str = "" # 类型
|
||||
theme: str = "" # 主题
|
||||
|
||||
# === P1-重要信息(按需包含)===
|
||||
chapter_characters: str = "" # 本章涉及角色(精简)
|
||||
emotional_tone: str = "" # 情感基调
|
||||
style_instruction: str = "" # 写作风格指令(摘要化)
|
||||
|
||||
# === P2-参考信息(条件触发)===
|
||||
relevant_memories: Optional[str] = None # 相关记忆(精简版)
|
||||
story_skeleton: Optional[str] = None # 故事骨架(50章+启用)
|
||||
mcp_references: Optional[str] = None # MCP参考资料
|
||||
|
||||
# === 元信息 ===
|
||||
context_stats: Dict[str, Any] = field(default_factory=dict) # 统计信息
|
||||
|
||||
def get_total_context_length(self) -> int:
|
||||
"""计算总上下文长度"""
|
||||
total = 0
|
||||
for field_name in ['chapter_outline', 'continuation_point', 'chapter_characters',
|
||||
'relevant_memories', 'story_skeleton', 'style_instruction']:
|
||||
value = getattr(self, field_name, None)
|
||||
if value:
|
||||
total += len(value)
|
||||
return total
|
||||
|
||||
|
||||
class ChapterContextBuilder:
|
||||
"""
|
||||
章节上下文构建器
|
||||
|
||||
实现动态裁剪逻辑,根据章节序号自动调整上下文复杂度:
|
||||
- 第1章:无前置上下文,仅提供大纲和角色
|
||||
- 第2-10章:上一章结尾300字 + 涉及角色
|
||||
- 第11-50章:上一章结尾500字 + 相关记忆3条
|
||||
- 第51章+:上一章结尾500字 + 故事骨架 + 智能记忆5条
|
||||
"""
|
||||
|
||||
# 配置常量
|
||||
ENDING_LENGTH_SHORT = 300 # 1-10章:短衔接
|
||||
ENDING_LENGTH_NORMAL = 500 # 11章+:标准衔接
|
||||
MEMORY_COUNT_LIGHT = 3 # 11-50章:轻量记忆
|
||||
MEMORY_COUNT_FULL = 5 # 51章+:完整记忆
|
||||
SKELETON_THRESHOLD = 50 # 启用故事骨架的章节阈值
|
||||
SKELETON_SAMPLE_INTERVAL = 10 # 故事骨架采样间隔
|
||||
MEMORY_IMPORTANCE_THRESHOLD = 0.7 # 记忆重要性阈值
|
||||
STYLE_MAX_LENGTH = 200 # 风格描述最大长度
|
||||
MAX_CONTEXT_LENGTH = 3000 # 总上下文最大字符数
|
||||
|
||||
def __init__(self, memory_service=None):
|
||||
"""
|
||||
初始化构建器
|
||||
|
||||
Args:
|
||||
memory_service: 记忆服务实例(可选,用于检索相关记忆)
|
||||
"""
|
||||
self.memory_service = memory_service
|
||||
|
||||
async def build(
|
||||
self,
|
||||
chapter: Chapter,
|
||||
project: Project,
|
||||
outline: Optional[Outline],
|
||||
user_id: str,
|
||||
db: AsyncSession,
|
||||
style_content: Optional[str] = None,
|
||||
target_word_count: int = 3000,
|
||||
temp_narrative_perspective: Optional[str] = None
|
||||
) -> ChapterContext:
|
||||
"""
|
||||
构建章节生成所需的上下文
|
||||
|
||||
Args:
|
||||
chapter: 章节对象
|
||||
project: 项目对象
|
||||
outline: 大纲对象(可选)
|
||||
user_id: 用户ID
|
||||
db: 数据库会话
|
||||
style_content: 写作风格内容(可选)
|
||||
target_word_count: 目标字数
|
||||
temp_narrative_perspective: 临时叙事视角(可选,覆盖项目默认)
|
||||
|
||||
Returns:
|
||||
ChapterContext: 结构化的上下文对象
|
||||
"""
|
||||
chapter_number = chapter.chapter_number
|
||||
logger.info(f"📝 开始构建章节上下文: 第{chapter_number}章")
|
||||
|
||||
# 确定叙事视角
|
||||
narrative_perspective = (
|
||||
temp_narrative_perspective or
|
||||
project.narrative_perspective or
|
||||
"第三人称"
|
||||
)
|
||||
|
||||
# 初始化上下文
|
||||
context = ChapterContext(
|
||||
chapter_number=chapter_number,
|
||||
chapter_title=chapter.title or "",
|
||||
title=project.title or "",
|
||||
genre=project.genre or "",
|
||||
theme=project.theme or "",
|
||||
target_word_count=target_word_count,
|
||||
min_word_count=max(500, target_word_count - 500),
|
||||
max_word_count=target_word_count + 1000,
|
||||
narrative_perspective=narrative_perspective
|
||||
)
|
||||
|
||||
# === P0-核心信息(始终构建)===
|
||||
context.chapter_outline = await self._build_chapter_outline(
|
||||
chapter, outline, project.outline_mode
|
||||
)
|
||||
|
||||
# === 衔接锚点(根据章节调整长度)===
|
||||
if chapter_number == 1:
|
||||
context.continuation_point = None
|
||||
logger.info(" ✅ 第1章无需衔接锚点")
|
||||
elif chapter_number <= 10:
|
||||
context.continuation_point = await self._get_last_ending(
|
||||
chapter, db, self.ENDING_LENGTH_SHORT
|
||||
)
|
||||
logger.info(f" ✅ 衔接锚点(短): {len(context.continuation_point or '')}字符")
|
||||
else:
|
||||
context.continuation_point = await self._get_last_ending(
|
||||
chapter, db, self.ENDING_LENGTH_NORMAL
|
||||
)
|
||||
logger.info(f" ✅ 衔接锚点(标准): {len(context.continuation_point or '')}字符")
|
||||
|
||||
# === P1-重要信息 ===
|
||||
context.chapter_characters = await self._build_chapter_characters(
|
||||
chapter, project, outline, db
|
||||
)
|
||||
context.emotional_tone = self._extract_emotional_tone(chapter, outline)
|
||||
|
||||
# 写作风格(摘要化)
|
||||
if style_content:
|
||||
context.style_instruction = self._summarize_style(style_content)
|
||||
|
||||
# === P2-参考信息(条件触发)===
|
||||
if chapter_number > 10 and self.memory_service:
|
||||
memory_limit = (
|
||||
self.MEMORY_COUNT_LIGHT if chapter_number <= 50
|
||||
else self.MEMORY_COUNT_FULL
|
||||
)
|
||||
context.relevant_memories = await self._get_relevant_memories(
|
||||
user_id, project.id, chapter_number,
|
||||
context.chapter_outline,
|
||||
limit=memory_limit
|
||||
)
|
||||
logger.info(f" ✅ 相关记忆: {len(context.relevant_memories or '')}字符")
|
||||
|
||||
# 故事骨架(50章+)
|
||||
if chapter_number > self.SKELETON_THRESHOLD:
|
||||
context.story_skeleton = await self._build_story_skeleton(
|
||||
project.id, chapter_number, db
|
||||
)
|
||||
logger.info(f" ✅ 故事骨架: {len(context.story_skeleton or '')}字符")
|
||||
|
||||
# === 统计信息 ===
|
||||
context.context_stats = {
|
||||
"chapter_number": chapter_number,
|
||||
"has_continuation": context.continuation_point is not None,
|
||||
"continuation_length": len(context.continuation_point or ""),
|
||||
"characters_length": len(context.chapter_characters),
|
||||
"memories_length": len(context.relevant_memories or ""),
|
||||
"skeleton_length": len(context.story_skeleton or ""),
|
||||
"total_length": context.get_total_context_length()
|
||||
}
|
||||
|
||||
logger.info(f"📊 上下文构建完成: 总长度 {context.context_stats['total_length']} 字符")
|
||||
|
||||
return context
|
||||
|
||||
async def _build_chapter_outline(
|
||||
self,
|
||||
chapter: Chapter,
|
||||
outline: Optional[Outline],
|
||||
outline_mode: str
|
||||
) -> str:
|
||||
"""
|
||||
构建本章大纲内容
|
||||
|
||||
Args:
|
||||
chapter: 章节对象
|
||||
outline: 大纲对象
|
||||
outline_mode: 大纲模式(one-to-one/one-to-many)
|
||||
|
||||
Returns:
|
||||
本章大纲文本
|
||||
"""
|
||||
if outline_mode == 'one-to-one':
|
||||
# 一对一模式:使用大纲的 content
|
||||
return outline.content if outline else chapter.summary or '暂无大纲'
|
||||
else:
|
||||
# 一对多模式:优先使用 expansion_plan 的详细规划
|
||||
if chapter.expansion_plan:
|
||||
try:
|
||||
plan = json.loads(chapter.expansion_plan)
|
||||
outline_content = f"""剧情摘要:{plan.get('plot_summary', '无')}
|
||||
|
||||
关键事件:
|
||||
{chr(10).join(f'- {event}' for event in plan.get('key_events', []))}
|
||||
|
||||
角色焦点:{', '.join(plan.get('character_focus', []))}
|
||||
情感基调:{plan.get('emotional_tone', '未设定')}
|
||||
叙事目标:{plan.get('narrative_goal', '未设定')}
|
||||
冲突类型:{plan.get('conflict_type', '未设定')}"""
|
||||
return outline_content
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 回退到大纲内容
|
||||
return outline.content if outline else chapter.summary or '暂无大纲'
|
||||
|
||||
async def _get_last_ending(
|
||||
self,
|
||||
chapter: Chapter,
|
||||
db: AsyncSession,
|
||||
max_length: int
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
获取上一章结尾内容作为衔接锚点
|
||||
|
||||
Args:
|
||||
chapter: 当前章节
|
||||
db: 数据库会话
|
||||
max_length: 最大长度
|
||||
|
||||
Returns:
|
||||
上一章结尾内容
|
||||
"""
|
||||
if chapter.chapter_number <= 1:
|
||||
return None
|
||||
|
||||
# 查询上一章
|
||||
result = await db.execute(
|
||||
select(Chapter)
|
||||
.where(Chapter.project_id == chapter.project_id)
|
||||
.where(Chapter.chapter_number == chapter.chapter_number - 1)
|
||||
)
|
||||
prev_chapter = result.scalar_one_or_none()
|
||||
|
||||
if not prev_chapter or not prev_chapter.content:
|
||||
return None
|
||||
|
||||
# 提取结尾内容
|
||||
content = prev_chapter.content.strip()
|
||||
if len(content) <= max_length:
|
||||
return content
|
||||
|
||||
return content[-max_length:]
|
||||
|
||||
async def _build_chapter_characters(
|
||||
self,
|
||||
chapter: Chapter,
|
||||
project: Project,
|
||||
outline: Optional[Outline],
|
||||
db: AsyncSession
|
||||
) -> str:
|
||||
"""
|
||||
构建本章涉及的角色信息(精简版)
|
||||
|
||||
只返回本章相关的角色,而非全部角色
|
||||
|
||||
Args:
|
||||
chapter: 章节对象
|
||||
project: 项目对象
|
||||
outline: 大纲对象
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
本章角色信息文本
|
||||
"""
|
||||
# 获取所有角色
|
||||
characters_result = await db.execute(
|
||||
select(Character).where(Character.project_id == project.id)
|
||||
)
|
||||
characters = characters_result.scalars().all()
|
||||
|
||||
if not characters:
|
||||
return "暂无角色信息"
|
||||
|
||||
# 提取本章相关角色名单
|
||||
filter_character_names = None
|
||||
|
||||
# 从大纲或扩展计划中提取角色
|
||||
if project.outline_mode == 'one-to-one':
|
||||
if outline and outline.structure:
|
||||
try:
|
||||
structure = json.loads(outline.structure)
|
||||
filter_character_names = structure.get('characters', [])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
else:
|
||||
if chapter.expansion_plan:
|
||||
try:
|
||||
plan = json.loads(chapter.expansion_plan)
|
||||
filter_character_names = plan.get('character_focus', [])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 筛选角色
|
||||
if filter_character_names:
|
||||
characters = [c for c in characters if c.name in filter_character_names]
|
||||
|
||||
if not characters:
|
||||
return "暂无相关角色"
|
||||
|
||||
# 构建精简的角色信息(每个角色最多100字符)
|
||||
char_lines = []
|
||||
for c in characters[:10]: # 最多10个角色
|
||||
role_type = "主角" if c.role_type == "protagonist" else (
|
||||
"反派" if c.role_type == "antagonist" else "配角"
|
||||
)
|
||||
|
||||
# 性格摘要(最多50字符)
|
||||
personality_brief = ""
|
||||
if c.personality:
|
||||
personality_brief = c.personality[:50]
|
||||
if len(c.personality) > 50:
|
||||
personality_brief += "..."
|
||||
|
||||
char_lines.append(f"- {c.name}({role_type}): {personality_brief}")
|
||||
|
||||
return "\n".join(char_lines)
|
||||
|
||||
def _extract_emotional_tone(
|
||||
self,
|
||||
chapter: Chapter,
|
||||
outline: Optional[Outline]
|
||||
) -> str:
|
||||
"""
|
||||
提取本章情感基调
|
||||
|
||||
Args:
|
||||
chapter: 章节对象
|
||||
outline: 大纲对象
|
||||
|
||||
Returns:
|
||||
情感基调描述
|
||||
"""
|
||||
# 尝试从扩展计划中提取
|
||||
if chapter.expansion_plan:
|
||||
try:
|
||||
plan = json.loads(chapter.expansion_plan)
|
||||
tone = plan.get('emotional_tone')
|
||||
if tone:
|
||||
return tone
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 尝试从大纲结构中提取
|
||||
if outline and outline.structure:
|
||||
try:
|
||||
structure = json.loads(outline.structure)
|
||||
tone = structure.get('emotion') or structure.get('emotional_tone')
|
||||
if tone:
|
||||
return tone
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return "未设定"
|
||||
|
||||
def _summarize_style(self, style_content: str) -> str:
|
||||
"""
|
||||
将风格描述压缩为关键要点
|
||||
|
||||
Args:
|
||||
style_content: 完整风格描述
|
||||
|
||||
Returns:
|
||||
摘要化的风格描述
|
||||
"""
|
||||
if not style_content:
|
||||
return ""
|
||||
|
||||
if len(style_content) <= self.STYLE_MAX_LENGTH:
|
||||
return style_content
|
||||
|
||||
# 简单截断(后续可以用AI提取关键词)
|
||||
return style_content[:self.STYLE_MAX_LENGTH] + "..."
|
||||
|
||||
async def _get_relevant_memories(
|
||||
self,
|
||||
user_id: str,
|
||||
project_id: str,
|
||||
chapter_number: int,
|
||||
chapter_outline: str,
|
||||
limit: int = 3
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
获取与本章最相关的记忆(精简版)
|
||||
|
||||
策略:
|
||||
1. 仅检索与大纲语义最相关的记忆
|
||||
2. 提高重要性阈值,过滤低质量记忆
|
||||
3. 优先返回未回收的伏笔
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
project_id: 项目ID
|
||||
chapter_number: 当前章节号
|
||||
chapter_outline: 本章大纲
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
格式化的记忆文本
|
||||
"""
|
||||
if not self.memory_service:
|
||||
return None
|
||||
|
||||
try:
|
||||
# 1. 语义检索相关记忆(提高阈值)
|
||||
relevant = await self.memory_service.search_memories(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
query=chapter_outline,
|
||||
limit=limit,
|
||||
min_importance=self.MEMORY_IMPORTANCE_THRESHOLD
|
||||
)
|
||||
|
||||
# 2. 检查即将到期的伏笔
|
||||
foreshadows = await self._get_due_foreshadows(
|
||||
user_id, project_id, chapter_number,
|
||||
lookahead=5 # 仅看5章内需要回收的
|
||||
)
|
||||
|
||||
# 3. 合并并格式化
|
||||
return self._format_memories(relevant, foreshadows, max_length=500)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 获取相关记忆失败: {str(e)}")
|
||||
return None
|
||||
|
||||
async def _get_due_foreshadows(
|
||||
self,
|
||||
user_id: str,
|
||||
project_id: str,
|
||||
chapter_number: int,
|
||||
lookahead: int = 5
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取即将需要回收的伏笔
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
project_id: 项目ID
|
||||
chapter_number: 当前章节号
|
||||
lookahead: 往前看的章节数
|
||||
|
||||
Returns:
|
||||
待回收伏笔列表
|
||||
"""
|
||||
if not self.memory_service:
|
||||
return []
|
||||
|
||||
try:
|
||||
foreshadows = await self.memory_service.find_unresolved_foreshadows(
|
||||
user_id, project_id, chapter_number
|
||||
)
|
||||
|
||||
# 过滤:只保留埋下时间较长(超过lookahead章)的伏笔
|
||||
due_foreshadows = []
|
||||
for fs in foreshadows:
|
||||
meta = fs.get('metadata', {})
|
||||
fs_chapter = meta.get('chapter_number', 0)
|
||||
if chapter_number - fs_chapter >= lookahead:
|
||||
due_foreshadows.append({
|
||||
'chapter': fs_chapter,
|
||||
'content': fs.get('content', '')[:60],
|
||||
'importance': meta.get('importance', 0.5)
|
||||
})
|
||||
|
||||
return due_foreshadows[:2] # 最多2条
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 获取待回收伏笔失败: {str(e)}")
|
||||
return []
|
||||
|
||||
def _format_memories(
|
||||
self,
|
||||
relevant: List[Dict[str, Any]],
|
||||
foreshadows: List[Dict[str, Any]],
|
||||
max_length: int = 500
|
||||
) -> str:
|
||||
"""
|
||||
格式化记忆为简洁文本,严格限制长度
|
||||
|
||||
Args:
|
||||
relevant: 相关记忆列表
|
||||
foreshadows: 待回收伏笔列表
|
||||
max_length: 最大长度
|
||||
|
||||
Returns:
|
||||
格式化的记忆文本
|
||||
"""
|
||||
lines = []
|
||||
current_length = 0
|
||||
|
||||
# 优先添加待回收伏笔
|
||||
if foreshadows:
|
||||
lines.append("【待回收伏笔】")
|
||||
for fs in foreshadows[:2]:
|
||||
text = f"- 第{fs['chapter']}章埋下:{fs['content']}"
|
||||
if current_length + len(text) > max_length:
|
||||
break
|
||||
lines.append(text)
|
||||
current_length += len(text)
|
||||
|
||||
# 添加相关记忆
|
||||
if relevant and current_length < max_length:
|
||||
lines.append("【相关记忆】")
|
||||
for mem in relevant:
|
||||
content = mem.get('content', '')[:80]
|
||||
text = f"- {content}"
|
||||
if current_length + len(text) > max_length:
|
||||
break
|
||||
lines.append(text)
|
||||
current_length += len(text)
|
||||
|
||||
return "\n".join(lines) if lines else None
|
||||
|
||||
async def _build_story_skeleton(
|
||||
self,
|
||||
project_id: str,
|
||||
chapter_number: int,
|
||||
db: AsyncSession
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
构建故事骨架(每N章采样)
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
chapter_number: 当前章节号
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
故事骨架文本
|
||||
"""
|
||||
try:
|
||||
# 获取所有已完成章节的摘要
|
||||
result = await db.execute(
|
||||
select(Chapter.chapter_number, Chapter.title)
|
||||
.where(Chapter.project_id == project_id)
|
||||
.where(Chapter.chapter_number < chapter_number)
|
||||
.where(Chapter.content != None)
|
||||
.where(Chapter.content != "")
|
||||
.order_by(Chapter.chapter_number)
|
||||
)
|
||||
chapters = result.all()
|
||||
|
||||
if not chapters:
|
||||
return None
|
||||
|
||||
# 采样:每N章取一个
|
||||
skeleton_lines = ["【故事骨架】"]
|
||||
for i, (ch_num, ch_title) in enumerate(chapters):
|
||||
if i % self.SKELETON_SAMPLE_INTERVAL == 0:
|
||||
# 尝试获取章节摘要
|
||||
summary_result = await db.execute(
|
||||
select(StoryMemory.content)
|
||||
.where(StoryMemory.project_id == project_id)
|
||||
.where(StoryMemory.story_timeline == ch_num)
|
||||
.where(StoryMemory.memory_type == 'chapter_summary')
|
||||
.limit(1)
|
||||
)
|
||||
summary = summary_result.scalar_one_or_none()
|
||||
|
||||
if summary:
|
||||
skeleton_lines.append(f"第{ch_num}章《{ch_title}》:{summary[:100]}")
|
||||
else:
|
||||
skeleton_lines.append(f"第{ch_num}章《{ch_title}》")
|
||||
|
||||
if len(skeleton_lines) <= 1:
|
||||
return None
|
||||
|
||||
return "\n".join(skeleton_lines)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 构建故事骨架失败: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
class FocusedMemoryRetriever:
|
||||
"""
|
||||
精简记忆检索器
|
||||
|
||||
相比原有的memory_service,提供更精准、更简洁的记忆检索
|
||||
"""
|
||||
|
||||
def __init__(self, memory_service):
|
||||
"""
|
||||
初始化检索器
|
||||
|
||||
Args:
|
||||
memory_service: 基础记忆服务实例
|
||||
"""
|
||||
self.memory_service = memory_service
|
||||
|
||||
async def get_relevant_memories(
|
||||
self,
|
||||
user_id: str,
|
||||
project_id: str,
|
||||
chapter_number: int,
|
||||
chapter_outline: str,
|
||||
limit: int = 3
|
||||
) -> str:
|
||||
"""
|
||||
获取与本章最相关的记忆
|
||||
|
||||
策略:
|
||||
1. 仅检索与大纲语义最相关的记忆
|
||||
2. 提高重要性阈值,过滤低质量记忆
|
||||
3. 优先返回未回收的伏笔
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
project_id: 项目ID
|
||||
chapter_number: 当前章节号
|
||||
chapter_outline: 本章大纲
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
格式化的记忆文本
|
||||
"""
|
||||
# 1. 语义检索相关记忆(提高阈值)
|
||||
relevant = await self.memory_service.search_memories(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
query=chapter_outline,
|
||||
limit=limit,
|
||||
min_importance=0.7 # 从0.4提高到0.7
|
||||
)
|
||||
|
||||
# 2. 检查即将到期的伏笔
|
||||
due_foreshadows = await self._get_due_foreshadows(
|
||||
user_id, project_id, chapter_number,
|
||||
lookahead=5 # 仅看5章内需要回收的
|
||||
)
|
||||
|
||||
# 3. 合并并格式化
|
||||
return self._format_memories(relevant, due_foreshadows, max_length=500)
|
||||
|
||||
async def _get_due_foreshadows(
|
||||
self,
|
||||
user_id: str,
|
||||
project_id: str,
|
||||
chapter_number: int,
|
||||
lookahead: int = 5
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取即将需要回收的伏笔"""
|
||||
foreshadows = await self.memory_service.find_unresolved_foreshadows(
|
||||
user_id, project_id, chapter_number
|
||||
)
|
||||
|
||||
# 过滤:只保留埋下时间较长的伏笔
|
||||
due_foreshadows = []
|
||||
for fs in foreshadows:
|
||||
meta = fs.get('metadata', {})
|
||||
fs_chapter = meta.get('chapter_number', 0)
|
||||
if chapter_number - fs_chapter >= lookahead:
|
||||
due_foreshadows.append({
|
||||
'chapter': fs_chapter,
|
||||
'content': fs.get('content', '')[:60],
|
||||
'importance': meta.get('importance', 0.5)
|
||||
})
|
||||
|
||||
return due_foreshadows[:2] # 最多2条
|
||||
|
||||
def _format_memories(
|
||||
self,
|
||||
relevant: List[Dict[str, Any]],
|
||||
foreshadows: List[Dict[str, Any]],
|
||||
max_length: int = 500
|
||||
) -> str:
|
||||
"""格式化为简洁文本,严格限制长度"""
|
||||
lines = []
|
||||
current_length = 0
|
||||
|
||||
# 优先添加待回收伏笔
|
||||
if foreshadows:
|
||||
lines.append("【待回收伏笔】")
|
||||
for fs in foreshadows[:2]:
|
||||
text = f"- 第{fs['chapter']}章埋下:{fs['content']}"
|
||||
if current_length + len(text) > max_length:
|
||||
break
|
||||
lines.append(text)
|
||||
current_length += len(text)
|
||||
|
||||
# 添加相关记忆
|
||||
if relevant and current_length < max_length:
|
||||
lines.append("【相关记忆】")
|
||||
for mem in relevant:
|
||||
content = mem.get('content', '')[:80]
|
||||
text = f"- {content}"
|
||||
if current_length + len(text) > max_length:
|
||||
break
|
||||
lines.append(text)
|
||||
current_length += len(text)
|
||||
|
||||
return "\n".join(lines) if lines else ""
|
||||
@@ -0,0 +1,248 @@
|
||||
"""章节重新生成服务"""
|
||||
from typing import Dict, Any, AsyncGenerator, Optional, List
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.prompt_service import prompt_service, PromptService
|
||||
from app.models.chapter import Chapter
|
||||
from app.models.memory import PlotAnalysis
|
||||
from app.schemas.regeneration import ChapterRegenerateRequest, PreserveElementsConfig
|
||||
from app.logger import get_logger
|
||||
import difflib
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ChapterRegenerator:
|
||||
"""章节重新生成服务"""
|
||||
|
||||
def __init__(self, ai_service: AIService):
|
||||
self.ai_service = ai_service
|
||||
logger.info("✅ ChapterRegenerator初始化成功")
|
||||
|
||||
async def regenerate_with_feedback(
|
||||
self,
|
||||
chapter: Chapter,
|
||||
analysis: Optional[PlotAnalysis],
|
||||
regenerate_request: ChapterRegenerateRequest,
|
||||
project_context: Dict[str, Any],
|
||||
style_content: str = "",
|
||||
user_id: str = None,
|
||||
db: AsyncSession = None
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""
|
||||
根据反馈重新生成章节(流式)
|
||||
|
||||
Args:
|
||||
chapter: 原始章节对象
|
||||
analysis: 分析结果(可选)
|
||||
regenerate_request: 重新生成请求参数
|
||||
project_context: 项目上下文(项目信息、角色、大纲等)
|
||||
style_content: 写作风格
|
||||
user_id: 用户ID(用于获取自定义提示词)
|
||||
db: 数据库会话(用于查询自定义提示词)
|
||||
|
||||
Yields:
|
||||
包含类型和数据的字典: {'type': 'progress'/'chunk', 'data': ...}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"🔄 开始重新生成章节: 第{chapter.chapter_number}章")
|
||||
|
||||
# 1. 构建修改指令
|
||||
yield {'type': 'progress', 'progress': 5, 'message': '正在构建修改指令...'}
|
||||
modification_instructions = self._build_modification_instructions(
|
||||
analysis=analysis,
|
||||
regenerate_request=regenerate_request
|
||||
)
|
||||
|
||||
logger.info(f"📝 修改指令构建完成,长度: {len(modification_instructions)}字符")
|
||||
|
||||
# 2. 构建完整提示词
|
||||
yield {'type': 'progress', 'progress': 10, 'message': '正在构建生成提示词...'}
|
||||
full_prompt = await self._build_regeneration_prompt(
|
||||
chapter=chapter,
|
||||
modification_instructions=modification_instructions,
|
||||
project_context=project_context,
|
||||
regenerate_request=regenerate_request,
|
||||
style_content=style_content,
|
||||
user_id=user_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
logger.info(f"🎯 提示词构建完成,开始AI生成")
|
||||
yield {'type': 'progress', 'progress': 15, 'message': '开始AI生成内容...'}
|
||||
|
||||
# 3. 构建系统提示词(注入写作风格)
|
||||
system_prompt_with_style = None
|
||||
if style_content:
|
||||
system_prompt_with_style = f"""【🎨 写作风格要求 - 最高优先级】
|
||||
|
||||
{style_content}
|
||||
|
||||
⚠️ 请严格遵循上述写作风格要求进行重写,这是最重要的指令!
|
||||
确保在整个章节重写过程中始终保持风格的一致性。"""
|
||||
logger.info(f"✅ 已将写作风格注入系统提示词({len(style_content)}字符)")
|
||||
|
||||
# 4. 流式生成新内容,同时跟踪进度
|
||||
target_word_count = regenerate_request.target_word_count
|
||||
accumulated_length = 0
|
||||
|
||||
async for chunk in self.ai_service.generate_text_stream(
|
||||
prompt=full_prompt,
|
||||
system_prompt=system_prompt_with_style,
|
||||
temperature=0.7
|
||||
):
|
||||
# 发送内容块
|
||||
yield {'type': 'chunk', 'content': chunk}
|
||||
|
||||
# 更新累积字数并计算进度(15%-95%)
|
||||
accumulated_length += len(chunk)
|
||||
# 进度从15%开始,到95%结束,为后处理预留5%
|
||||
generation_progress = min(15 + (accumulated_length / target_word_count) * 80, 95)
|
||||
yield {'type': 'progress', 'progress': int(generation_progress), 'word_count': accumulated_length}
|
||||
|
||||
logger.info(f"✅ 章节重新生成完成,共生成 {accumulated_length} 字")
|
||||
yield {'type': 'progress', 'progress': 100, 'message': '生成完成'}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 重新生成失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _build_modification_instructions(
|
||||
self,
|
||||
analysis: Optional[PlotAnalysis],
|
||||
regenerate_request: ChapterRegenerateRequest
|
||||
) -> str:
|
||||
"""构建修改指令"""
|
||||
|
||||
instructions = []
|
||||
|
||||
# 标题
|
||||
instructions.append("# 章节修改指令\n")
|
||||
|
||||
# 1. 来自分析的建议
|
||||
if (analysis and
|
||||
regenerate_request.selected_suggestion_indices and
|
||||
analysis.suggestions):
|
||||
|
||||
instructions.append("## 📋 需要改进的问题(来自AI分析):\n")
|
||||
for idx in regenerate_request.selected_suggestion_indices:
|
||||
if 0 <= idx < len(analysis.suggestions):
|
||||
suggestion = analysis.suggestions[idx]
|
||||
instructions.append(f"{idx + 1}. {suggestion}")
|
||||
instructions.append("")
|
||||
|
||||
# 2. 用户自定义指令
|
||||
if regenerate_request.custom_instructions:
|
||||
instructions.append("## ✍️ 用户自定义修改要求:\n")
|
||||
instructions.append(regenerate_request.custom_instructions)
|
||||
instructions.append("")
|
||||
|
||||
# 3. 重点优化方向
|
||||
if regenerate_request.focus_areas:
|
||||
instructions.append("## 🎯 重点优化方向:\n")
|
||||
focus_map = {
|
||||
"pacing": "节奏把控 - 调整叙事速度,避免拖沓或过快",
|
||||
"emotion": "情感渲染 - 深化人物情感表达,增强感染力",
|
||||
"description": "场景描写 - 丰富环境细节,增强画面感",
|
||||
"dialogue": "对话质量 - 让对话更自然真实,推动剧情",
|
||||
"conflict": "冲突强度 - 强化矛盾冲突,提升戏剧张力"
|
||||
}
|
||||
|
||||
for area in regenerate_request.focus_areas:
|
||||
if area in focus_map:
|
||||
instructions.append(f"- {focus_map[area]}")
|
||||
instructions.append("")
|
||||
|
||||
# 4. 保留要求
|
||||
if regenerate_request.preserve_elements:
|
||||
preserve = regenerate_request.preserve_elements
|
||||
instructions.append("## 🔒 必须保留的元素:\n")
|
||||
|
||||
if preserve.preserve_structure:
|
||||
instructions.append("- 保持原章节的整体结构和情节框架")
|
||||
|
||||
if preserve.preserve_dialogues:
|
||||
instructions.append("- 必须保留以下关键对话:")
|
||||
for dialogue in preserve.preserve_dialogues:
|
||||
instructions.append(f" * {dialogue}")
|
||||
|
||||
if preserve.preserve_plot_points:
|
||||
instructions.append("- 必须保留以下关键情节点:")
|
||||
for plot in preserve.preserve_plot_points:
|
||||
instructions.append(f" * {plot}")
|
||||
|
||||
if preserve.preserve_character_traits:
|
||||
instructions.append("- 保持所有角色的性格特征和行为模式一致")
|
||||
|
||||
instructions.append("")
|
||||
|
||||
return "\n".join(instructions)
|
||||
|
||||
async def _build_regeneration_prompt(
|
||||
self,
|
||||
chapter: Chapter,
|
||||
modification_instructions: str,
|
||||
project_context: Dict[str, Any],
|
||||
regenerate_request: ChapterRegenerateRequest,
|
||||
style_content: str = "",
|
||||
user_id: str = None,
|
||||
db: AsyncSession = None
|
||||
) -> str:
|
||||
"""构建完整的重新生成提示词"""
|
||||
# 使用PromptService的get_chapter_regeneration_prompt方法
|
||||
# 该方法会处理自定义模板加载和完整提示词构建
|
||||
return await PromptService.get_chapter_regeneration_prompt(
|
||||
chapter_number=chapter.chapter_number,
|
||||
title=chapter.title,
|
||||
word_count=chapter.word_count,
|
||||
content=chapter.content,
|
||||
modification_instructions=modification_instructions,
|
||||
project_context=project_context,
|
||||
style_content=style_content,
|
||||
target_word_count=regenerate_request.target_word_count,
|
||||
user_id=user_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
def calculate_content_diff(
|
||||
self,
|
||||
original_content: str,
|
||||
new_content: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
计算两个版本的差异
|
||||
|
||||
Returns:
|
||||
差异统计信息
|
||||
"""
|
||||
# 基本统计
|
||||
diff_stats = {
|
||||
'original_length': len(original_content),
|
||||
'new_length': len(new_content),
|
||||
'length_change': len(new_content) - len(original_content),
|
||||
'length_change_percent': round((len(new_content) - len(original_content)) / len(original_content) * 100, 2) if len(original_content) > 0 else 0
|
||||
}
|
||||
|
||||
# 计算相似度
|
||||
similarity = difflib.SequenceMatcher(None, original_content, new_content).ratio()
|
||||
diff_stats['similarity'] = round(similarity * 100, 2)
|
||||
diff_stats['difference'] = round((1 - similarity) * 100, 2)
|
||||
|
||||
# 段落统计
|
||||
original_paragraphs = [p for p in original_content.split('\n\n') if p.strip()]
|
||||
new_paragraphs = [p for p in new_content.split('\n\n') if p.strip()]
|
||||
diff_stats['original_paragraph_count'] = len(original_paragraphs)
|
||||
diff_stats['new_paragraph_count'] = len(new_paragraphs)
|
||||
|
||||
return diff_stats
|
||||
|
||||
|
||||
# 全局实例
|
||||
_regenerator_instance = None
|
||||
|
||||
def get_chapter_regenerator(ai_service: AIService) -> ChapterRegenerator:
|
||||
"""获取章节重新生成器实例"""
|
||||
global _regenerator_instance
|
||||
if _regenerator_instance is None:
|
||||
_regenerator_instance = ChapterRegenerator(ai_service)
|
||||
return _regenerator_instance
|
||||
@@ -1,7 +1,7 @@
|
||||
"""导入导出服务"""
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.models.project import Project
|
||||
@@ -77,6 +77,8 @@ class ImportExportService:
|
||||
"chapter_count": project.chapter_count,
|
||||
"narrative_perspective": project.narrative_perspective,
|
||||
"character_count": project.character_count,
|
||||
"outline_mode": project.outline_mode,
|
||||
"user_id": project.user_id,
|
||||
"created_at": project.created_at.isoformat() if project.created_at else None,
|
||||
}
|
||||
|
||||
@@ -143,18 +145,41 @@ class ImportExportService:
|
||||
)
|
||||
chapters = result.scalars().all()
|
||||
|
||||
return [
|
||||
ChapterExportData(
|
||||
# 构建大纲ID到标题的映射
|
||||
outline_mapping = {}
|
||||
if chapters:
|
||||
outline_ids = [ch.outline_id for ch in chapters if ch.outline_id]
|
||||
if outline_ids:
|
||||
outline_result = await db.execute(
|
||||
select(Outline).where(Outline.id.in_(outline_ids))
|
||||
)
|
||||
outlines = outline_result.scalars().all()
|
||||
outline_mapping = {ol.id: ol.title for ol in outlines}
|
||||
|
||||
exported_chapters = []
|
||||
for ch in chapters:
|
||||
# 解析expansion_plan JSON
|
||||
expansion_plan = None
|
||||
if ch.expansion_plan:
|
||||
try:
|
||||
expansion_plan = json.loads(ch.expansion_plan) if isinstance(ch.expansion_plan, str) else ch.expansion_plan
|
||||
except:
|
||||
expansion_plan = None
|
||||
|
||||
exported_chapters.append(ChapterExportData(
|
||||
title=ch.title,
|
||||
content=ch.content,
|
||||
summary=ch.summary,
|
||||
chapter_number=ch.chapter_number,
|
||||
word_count=ch.word_count or 0,
|
||||
status=ch.status,
|
||||
created_at=ch.created_at.isoformat() if ch.created_at else None
|
||||
)
|
||||
for ch in chapters
|
||||
]
|
||||
created_at=ch.created_at.isoformat() if ch.created_at else None,
|
||||
outline_title=outline_mapping.get(ch.outline_id) if ch.outline_id else None,
|
||||
sub_index=ch.sub_index,
|
||||
expansion_plan=expansion_plan
|
||||
))
|
||||
|
||||
return exported_chapters
|
||||
|
||||
@staticmethod
|
||||
async def _export_characters(project_id: str, db: AsyncSession) -> List[CharacterExportData]:
|
||||
@@ -315,10 +340,19 @@ class ImportExportService:
|
||||
|
||||
@staticmethod
|
||||
async def _export_writing_styles(project_id: str, db: AsyncSession) -> List[WritingStyleExportData]:
|
||||
"""导出写作风格"""
|
||||
"""导出写作风格(用户自定义风格)"""
|
||||
# 获取项目所属用户
|
||||
project_result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = project_result.scalar_one_or_none()
|
||||
if not project:
|
||||
return []
|
||||
|
||||
# 导出该用户的自定义风格(不包括全局预设)
|
||||
result = await db.execute(
|
||||
select(WritingStyle)
|
||||
.where(WritingStyle.project_id == project_id)
|
||||
.where(WritingStyle.user_id == project.user_id)
|
||||
.order_by(WritingStyle.order_index)
|
||||
)
|
||||
styles = result.scalars().all()
|
||||
@@ -423,7 +457,8 @@ class ImportExportService:
|
||||
@staticmethod
|
||||
async def import_project(
|
||||
data: Dict,
|
||||
db: AsyncSession
|
||||
db: AsyncSession,
|
||||
user_id: str
|
||||
) -> ImportResult:
|
||||
"""
|
||||
导入项目数据(创建新项目)
|
||||
@@ -431,6 +466,7 @@ class ImportExportService:
|
||||
Args:
|
||||
data: 导入的JSON数据
|
||||
db: 数据库会话
|
||||
user_id: 目标用户ID(导入后的项目归属)
|
||||
|
||||
Returns:
|
||||
ImportResult: 导入结果
|
||||
@@ -456,6 +492,7 @@ class ImportExportService:
|
||||
# 创建项目
|
||||
project_data = data["project"]
|
||||
new_project = Project(
|
||||
user_id=user_id, # 设置为当前用户ID
|
||||
title=project_data.get("title"),
|
||||
description=project_data.get("description"),
|
||||
theme=project_data.get("theme"),
|
||||
@@ -469,6 +506,7 @@ class ImportExportService:
|
||||
chapter_count=project_data.get("chapter_count"),
|
||||
narrative_perspective=project_data.get("narrative_perspective"),
|
||||
character_count=project_data.get("character_count"),
|
||||
outline_mode=project_data.get("outline_mode", "one-to-many"), # ✅ 导入大纲模式,默认为一对多
|
||||
current_words=project_data.get("current_words", 0), # 保留原项目的字数
|
||||
wizard_step=4, # 导入的项目设置为向导完成状态
|
||||
wizard_status="completed" # 标记向导已完成
|
||||
@@ -478,26 +516,26 @@ class ImportExportService:
|
||||
|
||||
logger.info(f"创建项目成功: {new_project.id}")
|
||||
|
||||
# 导入章节
|
||||
chapters_count = await ImportExportService._import_chapters(
|
||||
new_project.id, data.get("chapters", []), db
|
||||
)
|
||||
statistics["chapters"] = chapters_count
|
||||
logger.info(f"导入章节数: {chapters_count}")
|
||||
|
||||
# 导入角色(包括组织)
|
||||
# 导入角色(包括组织)- 需要先导入角色,因为大纲可能需要角色信息
|
||||
char_mapping = await ImportExportService._import_characters(
|
||||
new_project.id, data.get("characters", []), db
|
||||
)
|
||||
statistics["characters"] = len(char_mapping)
|
||||
logger.info(f"导入角色数: {len(char_mapping)}")
|
||||
|
||||
# 导入大纲
|
||||
outlines_count = await ImportExportService._import_outlines(
|
||||
# 导入大纲 - 需要在章节之前导入,以便建立关联
|
||||
outline_mapping = await ImportExportService._import_outlines(
|
||||
new_project.id, data.get("outlines", []), db
|
||||
)
|
||||
statistics["outlines"] = outlines_count
|
||||
logger.info(f"导入大纲数: {outlines_count}")
|
||||
statistics["outlines"] = len(outline_mapping)
|
||||
logger.info(f"导入大纲数: {len(outline_mapping)}")
|
||||
|
||||
# 导入章节 - 使用大纲映射重建关联关系
|
||||
chapters_count = await ImportExportService._import_chapters(
|
||||
new_project.id, data.get("chapters", []), outline_mapping, db
|
||||
)
|
||||
statistics["chapters"] = chapters_count
|
||||
logger.info(f"导入章节数: {chapters_count}")
|
||||
|
||||
# 导入关系
|
||||
relationships_count = await ImportExportService._import_relationships(
|
||||
@@ -554,11 +592,23 @@ class ImportExportService:
|
||||
async def _import_chapters(
|
||||
project_id: str,
|
||||
chapters_data: List[Dict],
|
||||
outline_mapping: Dict[str, str],
|
||||
db: AsyncSession
|
||||
) -> int:
|
||||
"""导入章节"""
|
||||
count = 0
|
||||
for ch_data in chapters_data:
|
||||
# 根据大纲标题查找对应的新大纲ID
|
||||
outline_id = None
|
||||
outline_title = ch_data.get("outline_title")
|
||||
if outline_title and outline_title in outline_mapping:
|
||||
outline_id = outline_mapping[outline_title]
|
||||
|
||||
# 处理expansion_plan
|
||||
expansion_plan = ch_data.get("expansion_plan")
|
||||
if expansion_plan and isinstance(expansion_plan, dict):
|
||||
expansion_plan = json.dumps(expansion_plan, ensure_ascii=False)
|
||||
|
||||
chapter = Chapter(
|
||||
project_id=project_id,
|
||||
title=ch_data.get("title"),
|
||||
@@ -566,7 +616,10 @@ class ImportExportService:
|
||||
summary=ch_data.get("summary"),
|
||||
chapter_number=ch_data.get("chapter_number"),
|
||||
word_count=ch_data.get("word_count", 0),
|
||||
status=ch_data.get("status", "draft")
|
||||
status=ch_data.get("status", "draft"),
|
||||
outline_id=outline_id,
|
||||
sub_index=ch_data.get("sub_index"),
|
||||
expansion_plan=expansion_plan
|
||||
)
|
||||
db.add(chapter)
|
||||
count += 1
|
||||
@@ -585,7 +638,7 @@ class ImportExportService:
|
||||
for char_data in characters_data:
|
||||
# 处理traits
|
||||
traits = char_data.get("traits")
|
||||
if traits and isinstance(traits, list):
|
||||
if isinstance(traits, list):
|
||||
traits = json.dumps(traits, ensure_ascii=False)
|
||||
|
||||
character = Character(
|
||||
@@ -613,9 +666,10 @@ class ImportExportService:
|
||||
project_id: str,
|
||||
outlines_data: List[Dict],
|
||||
db: AsyncSession
|
||||
) -> int:
|
||||
"""导入大纲"""
|
||||
count = 0
|
||||
) -> Dict[str, str]:
|
||||
"""导入大纲,返回标题到ID的映射"""
|
||||
outline_mapping = {}
|
||||
|
||||
for ol_data in outlines_data:
|
||||
outline = Outline(
|
||||
project_id=project_id,
|
||||
@@ -625,9 +679,10 @@ class ImportExportService:
|
||||
order_index=ol_data.get("order_index")
|
||||
)
|
||||
db.add(outline)
|
||||
count += 1
|
||||
await db.flush() # 获取ID
|
||||
outline_mapping[ol_data.get("title")] = outline.id
|
||||
|
||||
return count
|
||||
return outline_mapping
|
||||
|
||||
@staticmethod
|
||||
async def _import_relationships(
|
||||
@@ -751,11 +806,30 @@ class ImportExportService:
|
||||
styles_data: List[Dict],
|
||||
db: AsyncSession
|
||||
) -> int:
|
||||
"""导入写作风格"""
|
||||
"""导入写作风格(用户自定义风格)"""
|
||||
# 获取项目所属用户
|
||||
project_result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = project_result.scalar_one_or_none()
|
||||
if not project:
|
||||
return 0
|
||||
|
||||
count = 0
|
||||
for style_data in styles_data:
|
||||
# 检查是否已存在同名风格(避免重复导入)
|
||||
existing = await db.execute(
|
||||
select(WritingStyle).where(
|
||||
WritingStyle.user_id == project.user_id,
|
||||
WritingStyle.name == style_data.get("name")
|
||||
)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
logger.debug(f"风格 {style_data.get('name')} 已存在,跳过导入")
|
||||
continue
|
||||
|
||||
style = WritingStyle(
|
||||
project_id=project_id,
|
||||
user_id=project.user_id, # 使用 user_id 而不是 project_id
|
||||
name=style_data.get("name"),
|
||||
style_type=style_data.get("style_type"),
|
||||
preset_id=style_data.get("preset_id"),
|
||||
@@ -766,4 +840,403 @@ class ImportExportService:
|
||||
db.add(style)
|
||||
count += 1
|
||||
|
||||
return count
|
||||
return count
|
||||
|
||||
@staticmethod
|
||||
async def export_characters(
|
||||
character_ids: List[str],
|
||||
db: AsyncSession
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
导出角色/组织卡片
|
||||
|
||||
Args:
|
||||
character_ids: 要导出的角色/组织ID列表
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
Dict: 导出的角色数据
|
||||
"""
|
||||
logger.info(f"开始导出角色/组织: {len(character_ids)} 个")
|
||||
|
||||
# 查询角色数据
|
||||
result = await db.execute(
|
||||
select(Character).where(Character.id.in_(character_ids))
|
||||
)
|
||||
characters = result.scalars().all()
|
||||
|
||||
if not characters:
|
||||
raise ValueError("未找到指定的角色/组织")
|
||||
|
||||
# 导出角色数据
|
||||
exported_characters = []
|
||||
for char in characters:
|
||||
# 解析 traits
|
||||
traits = None
|
||||
if char.traits:
|
||||
try:
|
||||
traits = json.loads(char.traits) if isinstance(char.traits, str) else char.traits
|
||||
except:
|
||||
traits = None
|
||||
|
||||
# 基础角色数据
|
||||
char_data = {
|
||||
"name": char.name,
|
||||
"age": char.age,
|
||||
"gender": char.gender,
|
||||
"is_organization": char.is_organization or False,
|
||||
"role_type": char.role_type,
|
||||
"personality": char.personality,
|
||||
"background": char.background,
|
||||
"appearance": char.appearance,
|
||||
"relationships": char.relationships,
|
||||
"traits": traits,
|
||||
"organization_type": char.organization_type,
|
||||
"organization_purpose": char.organization_purpose,
|
||||
"organization_members": char.organization_members,
|
||||
"avatar_url": char.avatar_url,
|
||||
"main_career_id": char.main_career_id,
|
||||
"main_career_stage": char.main_career_stage,
|
||||
"sub_careers": char.sub_careers,
|
||||
"created_at": char.created_at.isoformat() if char.created_at else None
|
||||
}
|
||||
|
||||
# 如果是组织,添加组织专属字段
|
||||
if char.is_organization:
|
||||
org_result = await db.execute(
|
||||
select(Organization).where(Organization.character_id == char.id)
|
||||
)
|
||||
org = org_result.scalar_one_or_none()
|
||||
|
||||
if org:
|
||||
char_data.update({
|
||||
"power_level": org.power_level,
|
||||
"location": org.location,
|
||||
"motto": org.motto,
|
||||
"color": org.color
|
||||
})
|
||||
|
||||
exported_characters.append(char_data)
|
||||
|
||||
export_data = {
|
||||
"version": ImportExportService.SUPPORTED_VERSION,
|
||||
"export_time": datetime.utcnow().isoformat(),
|
||||
"export_type": "characters",
|
||||
"count": len(exported_characters),
|
||||
"data": exported_characters
|
||||
}
|
||||
|
||||
logger.info(f"角色/组织导出完成: {len(exported_characters)} 个")
|
||||
return export_data
|
||||
|
||||
@staticmethod
|
||||
async def import_characters(
|
||||
data: Dict,
|
||||
project_id: str,
|
||||
user_id: str,
|
||||
db: AsyncSession
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
导入角色/组织卡片
|
||||
|
||||
Args:
|
||||
data: 导入的JSON数据
|
||||
project_id: 目标项目ID
|
||||
user_id: 用户ID
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
Dict: 导入结果
|
||||
"""
|
||||
from app.models.career import CharacterCareer, Career
|
||||
|
||||
warnings = []
|
||||
imported_characters = []
|
||||
imported_organizations = []
|
||||
skipped = []
|
||||
errors = []
|
||||
|
||||
try:
|
||||
# 验证数据格式
|
||||
if "data" not in data:
|
||||
raise ValueError("导入数据格式错误:缺少data字段")
|
||||
|
||||
characters_data = data["data"]
|
||||
if not isinstance(characters_data, list):
|
||||
raise ValueError("导入数据格式错误:data字段必须是数组")
|
||||
|
||||
# 验证项目权限
|
||||
project_result = await db.execute(
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = project_result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise ValueError("项目不存在或无权访问")
|
||||
|
||||
logger.info(f"开始导入 {len(characters_data)} 个角色/组织到项目 {project_id}")
|
||||
|
||||
# 处理每个角色/组织
|
||||
for idx, char_data in enumerate(characters_data):
|
||||
try:
|
||||
name = char_data.get("name")
|
||||
if not name:
|
||||
errors.append(f"第{idx+1}个角色缺少name字段")
|
||||
continue
|
||||
|
||||
# 检查重复名称
|
||||
existing_result = await db.execute(
|
||||
select(Character).where(
|
||||
Character.project_id == project_id,
|
||||
Character.name == name
|
||||
)
|
||||
)
|
||||
existing = existing_result.scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
warnings.append(f"角色'{name}'已存在,已跳过")
|
||||
skipped.append(name)
|
||||
continue
|
||||
|
||||
# 处理traits
|
||||
traits = char_data.get("traits")
|
||||
if isinstance(traits, list):
|
||||
traits = json.dumps(traits, ensure_ascii=False)
|
||||
|
||||
is_organization = char_data.get("is_organization", False)
|
||||
|
||||
# 创建角色
|
||||
character = Character(
|
||||
project_id=project_id,
|
||||
name=name,
|
||||
age=char_data.get("age"),
|
||||
gender=char_data.get("gender"),
|
||||
is_organization=is_organization,
|
||||
role_type=char_data.get("role_type"),
|
||||
personality=char_data.get("personality"),
|
||||
background=char_data.get("background"),
|
||||
appearance=char_data.get("appearance"),
|
||||
relationships=char_data.get("relationships"),
|
||||
traits=traits,
|
||||
organization_type=char_data.get("organization_type"),
|
||||
organization_purpose=char_data.get("organization_purpose"),
|
||||
organization_members=char_data.get("organization_members"),
|
||||
avatar_url=char_data.get("avatar_url"),
|
||||
main_career_id=None, # 职业ID需要验证后再设置
|
||||
main_career_stage=char_data.get("main_career_stage"),
|
||||
sub_careers=None # 副职业需要验证后再设置
|
||||
)
|
||||
db.add(character)
|
||||
await db.flush() # 获取character.id
|
||||
|
||||
# 处理主职业(如果有)
|
||||
main_career_id = char_data.get("main_career_id")
|
||||
main_career_stage = char_data.get("main_career_stage")
|
||||
|
||||
if main_career_id and not is_organization:
|
||||
# 验证职业是否存在
|
||||
career_result = await db.execute(
|
||||
select(Career).where(
|
||||
Career.id == main_career_id,
|
||||
Career.project_id == project_id,
|
||||
Career.type == 'main'
|
||||
)
|
||||
)
|
||||
career = career_result.scalar_one_or_none()
|
||||
|
||||
if career:
|
||||
character.main_career_id = main_career_id
|
||||
character.main_career_stage = main_career_stage or 1
|
||||
|
||||
# 创建职业关联
|
||||
char_career = CharacterCareer(
|
||||
character_id=character.id,
|
||||
career_id=main_career_id,
|
||||
career_type='main',
|
||||
current_stage=main_career_stage or 1,
|
||||
stage_progress=0
|
||||
)
|
||||
db.add(char_career)
|
||||
else:
|
||||
warnings.append(f"角色'{name}'的主职业ID不存在,已忽略职业信息")
|
||||
|
||||
# 处理副职业(如果有)
|
||||
sub_careers = char_data.get("sub_careers")
|
||||
if sub_careers and not is_organization:
|
||||
try:
|
||||
sub_careers_data = json.loads(sub_careers) if isinstance(sub_careers, str) else sub_careers
|
||||
|
||||
if isinstance(sub_careers_data, list):
|
||||
valid_sub_careers = []
|
||||
|
||||
for sub_data in sub_careers_data[:2]: # 最多2个副职业
|
||||
if isinstance(sub_data, dict):
|
||||
career_id = sub_data.get('career_id')
|
||||
stage = sub_data.get('stage', 1)
|
||||
|
||||
if career_id:
|
||||
# 验证副职业是否存在
|
||||
career_result = await db.execute(
|
||||
select(Career).where(
|
||||
Career.id == career_id,
|
||||
Career.project_id == project_id,
|
||||
Career.type == 'sub'
|
||||
)
|
||||
)
|
||||
career = career_result.scalar_one_or_none()
|
||||
|
||||
if career:
|
||||
valid_sub_careers.append({
|
||||
'career_id': career_id,
|
||||
'stage': stage
|
||||
})
|
||||
|
||||
# 创建副职业关联
|
||||
char_career = CharacterCareer(
|
||||
character_id=character.id,
|
||||
career_id=career_id,
|
||||
career_type='sub',
|
||||
current_stage=stage,
|
||||
stage_progress=0
|
||||
)
|
||||
db.add(char_career)
|
||||
|
||||
if valid_sub_careers:
|
||||
character.sub_careers = json.dumps(valid_sub_careers, ensure_ascii=False)
|
||||
elif sub_careers_data:
|
||||
warnings.append(f"角色'{name}'的副职业ID不存在,已忽略副职业信息")
|
||||
except Exception as e:
|
||||
warnings.append(f"角色'{name}'的副职业数据解析失败: {str(e)}")
|
||||
|
||||
# 如果是组织,创建Organization记录
|
||||
if is_organization:
|
||||
organization = Organization(
|
||||
character_id=character.id,
|
||||
project_id=project_id,
|
||||
member_count=0,
|
||||
power_level=char_data.get("power_level", 50),
|
||||
location=char_data.get("location"),
|
||||
motto=char_data.get("motto"),
|
||||
color=char_data.get("color")
|
||||
)
|
||||
db.add(organization)
|
||||
await db.flush()
|
||||
imported_organizations.append(name)
|
||||
else:
|
||||
imported_characters.append(name)
|
||||
|
||||
logger.info(f"导入{'组织' if is_organization else '角色'}成功: {name}")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"导入角色'{char_data.get('name', f'第{idx+1}个')}'失败: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
errors.append(error_msg)
|
||||
continue
|
||||
|
||||
# 提交事务
|
||||
await db.commit()
|
||||
|
||||
total = len(imported_characters) + len(imported_organizations)
|
||||
|
||||
result = {
|
||||
"success": True,
|
||||
"message": f"成功导入 {total} 个角色/组织",
|
||||
"statistics": {
|
||||
"total": len(characters_data),
|
||||
"imported": total,
|
||||
"skipped": len(skipped),
|
||||
"errors": len(errors)
|
||||
},
|
||||
"details": {
|
||||
"imported_characters": imported_characters,
|
||||
"imported_organizations": imported_organizations,
|
||||
"skipped": skipped,
|
||||
"errors": errors
|
||||
},
|
||||
"warnings": warnings
|
||||
}
|
||||
|
||||
logger.info(f"角色/组织导入完成: 成功{total}个,跳过{len(skipped)}个,失败{len(errors)}个")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"导入角色/组织失败: {str(e)}", exc_info=True)
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"导入失败: {str(e)}",
|
||||
"statistics": {
|
||||
"total": len(characters_data) if "data" in data else 0,
|
||||
"imported": len(imported_characters) + len(imported_organizations),
|
||||
"skipped": len(skipped),
|
||||
"errors": len(errors)
|
||||
},
|
||||
"details": {
|
||||
"imported_characters": imported_characters,
|
||||
"imported_organizations": imported_organizations,
|
||||
"skipped": skipped,
|
||||
"errors": errors
|
||||
},
|
||||
"warnings": warnings
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def validate_characters_import(data: Dict) -> Dict[str, Any]:
|
||||
"""
|
||||
验证角色/组织导入数据
|
||||
|
||||
Args:
|
||||
data: 导入的JSON数据
|
||||
|
||||
Returns:
|
||||
Dict: 验证结果
|
||||
"""
|
||||
errors = []
|
||||
warnings = []
|
||||
|
||||
# 检查版本
|
||||
version = data.get("version", "")
|
||||
if not version:
|
||||
errors.append("缺少版本信息")
|
||||
elif version != ImportExportService.SUPPORTED_VERSION:
|
||||
warnings.append(f"版本不匹配: 导入文件版本为 {version}, 当前支持版本为 {ImportExportService.SUPPORTED_VERSION}")
|
||||
|
||||
# 检查导出类型
|
||||
export_type = data.get("export_type", "")
|
||||
if export_type != "characters":
|
||||
errors.append(f"导出类型错误: 期望'characters',实际'{export_type}'")
|
||||
|
||||
# 检查数据字段
|
||||
if "data" not in data:
|
||||
errors.append("缺少data字段")
|
||||
elif not isinstance(data["data"], list):
|
||||
errors.append("data字段必须是数组")
|
||||
else:
|
||||
characters_data = data["data"]
|
||||
|
||||
# 统计信息
|
||||
character_count = sum(1 for c in characters_data if not c.get("is_organization", False))
|
||||
org_count = sum(1 for c in characters_data if c.get("is_organization", False))
|
||||
|
||||
# 检查必填字段
|
||||
for idx, char_data in enumerate(characters_data):
|
||||
if not char_data.get("name"):
|
||||
errors.append(f"第{idx+1}个角色缺少name字段")
|
||||
|
||||
statistics = {
|
||||
"characters": character_count,
|
||||
"organizations": org_count
|
||||
}
|
||||
|
||||
if "data" not in data or errors:
|
||||
statistics = {"characters": 0, "organizations": 0}
|
||||
|
||||
return {
|
||||
"valid": len(errors) == 0,
|
||||
"version": version,
|
||||
"statistics": statistics,
|
||||
"errors": errors,
|
||||
"warnings": warnings
|
||||
}
|
||||
@@ -0,0 +1,159 @@
|
||||
"""JSON 处理工具类"""
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, List, Union
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def clean_json_response(text: str) -> str:
|
||||
"""清洗 AI 返回的 JSON(改进版 - 流式安全)"""
|
||||
try:
|
||||
if not text:
|
||||
logger.warning("⚠️ clean_json_response: 输入为空")
|
||||
return text
|
||||
|
||||
original_length = len(text)
|
||||
logger.debug(f"🔍 开始清洗JSON,原始长度: {original_length}")
|
||||
|
||||
# 去除 markdown 代码块
|
||||
text = re.sub(r'^```json\s*\n?', '', text, flags=re.MULTILINE | re.IGNORECASE)
|
||||
text = re.sub(r'^```\s*\n?', '', text, flags=re.MULTILINE)
|
||||
text = re.sub(r'\n?```\s*$', '', text, flags=re.MULTILINE)
|
||||
text = text.strip()
|
||||
|
||||
if len(text) != original_length:
|
||||
logger.debug(f" 移除markdown后长度: {len(text)}")
|
||||
|
||||
# 尝试直接解析(快速路径)
|
||||
try:
|
||||
json.loads(text)
|
||||
logger.debug(f"✅ 直接解析成功,无需清洗")
|
||||
return text
|
||||
except:
|
||||
pass
|
||||
|
||||
# 找到第一个 { 或 [
|
||||
start = -1
|
||||
for i, c in enumerate(text):
|
||||
if c in ('{', '['):
|
||||
start = i
|
||||
break
|
||||
|
||||
if start == -1:
|
||||
logger.warning(f"⚠️ 未找到JSON起始符号 {{ 或 [")
|
||||
logger.debug(f" 文本预览: {text[:200]}")
|
||||
return text
|
||||
|
||||
if start > 0:
|
||||
logger.debug(f" 跳过前{start}个字符")
|
||||
text = text[start:]
|
||||
|
||||
# 改进的括号匹配算法(更严格的字符串处理)
|
||||
stack = []
|
||||
i = 0
|
||||
end = -1
|
||||
in_string = False
|
||||
|
||||
while i < len(text):
|
||||
c = text[i]
|
||||
|
||||
# 处理字符串状态
|
||||
if c == '"':
|
||||
if not in_string:
|
||||
# 进入字符串
|
||||
in_string = True
|
||||
else:
|
||||
# 检查是否是转义的引号
|
||||
num_backslashes = 0
|
||||
j = i - 1
|
||||
while j >= 0 and text[j] == '\\':
|
||||
num_backslashes += 1
|
||||
j -= 1
|
||||
|
||||
# 偶数个反斜杠表示引号未被转义,字符串结束
|
||||
if num_backslashes % 2 == 0:
|
||||
in_string = False
|
||||
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# 在字符串内部,跳过所有字符
|
||||
if in_string:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# 处理括号(只有在字符串外部才有效)
|
||||
if c == '{' or c == '[':
|
||||
stack.append(c)
|
||||
elif c == '}':
|
||||
if len(stack) > 0 and stack[-1] == '{':
|
||||
stack.pop()
|
||||
if len(stack) == 0:
|
||||
end = i + 1
|
||||
logger.debug(f"✅ 找到JSON结束位置: {end}")
|
||||
break
|
||||
elif len(stack) > 0:
|
||||
# 括号不匹配,可能是损坏的JSON,尝试继续
|
||||
logger.warning(f"⚠️ 括号不匹配:遇到 }} 但栈顶是 {stack[-1]}")
|
||||
else:
|
||||
# 栈为空遇到 },忽略多余的闭合括号
|
||||
logger.warning(f"⚠️ 遇到多余的 }},忽略")
|
||||
elif c == ']':
|
||||
if len(stack) > 0 and stack[-1] == '[':
|
||||
stack.pop()
|
||||
if len(stack) == 0:
|
||||
end = i + 1
|
||||
logger.debug(f"✅ 找到JSON结束位置: {end}")
|
||||
break
|
||||
elif len(stack) > 0:
|
||||
# 括号不匹配,可能是损坏的JSON,尝试继续
|
||||
logger.warning(f"⚠️ 括号不匹配:遇到 ] 但栈顶是 {stack[-1]}")
|
||||
else:
|
||||
# 栈为空遇到 ],忽略多余的闭合括号
|
||||
logger.warning(f"⚠️ 遇到多余的 ],忽略")
|
||||
|
||||
i += 1
|
||||
|
||||
# 检查未闭合的字符串
|
||||
if in_string:
|
||||
logger.warning(f"⚠️ 字符串未闭合,JSON可能不完整")
|
||||
|
||||
# 提取结果
|
||||
if end > 0:
|
||||
result = text[:end]
|
||||
logger.debug(f"✅ JSON清洗完成,结果长度: {len(result)}")
|
||||
else:
|
||||
result = text
|
||||
logger.warning(f"⚠️ 未找到JSON结束位置,返回全部内容(长度: {len(result)})")
|
||||
logger.debug(f" 栈状态: {stack}")
|
||||
|
||||
# 验证清洗后的结果
|
||||
try:
|
||||
json.loads(result)
|
||||
logger.debug(f"✅ 清洗后JSON验证成功")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"❌ 清洗后JSON仍然无效: {e}")
|
||||
logger.debug(f" 结果预览: {result[:500]}")
|
||||
logger.debug(f" 结果结尾: ...{result[-200:]}")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ clean_json_response 出错: {e}")
|
||||
logger.error(f" 文本长度: {len(text) if text else 0}")
|
||||
logger.error(f" 文本预览: {text[:200] if text else 'None'}")
|
||||
raise
|
||||
|
||||
|
||||
def parse_json(text: str) -> Union[Dict, List]:
|
||||
"""解析 JSON"""
|
||||
try:
|
||||
cleaned = clean_json_response(text)
|
||||
return json.loads(cleaned)
|
||||
except Exception as e:
|
||||
logger.error(f"❌ parse_json 出错: {e}")
|
||||
logger.error(f" 原始文本长度: {len(text) if text else 0}")
|
||||
logger.error(f" 清洗后文本长度: {len(cleaned) if cleaned else 0}")
|
||||
raise
|
||||
@@ -0,0 +1,347 @@
|
||||
"""MCP插件测试服务 - 专门处理插件测试逻辑
|
||||
|
||||
重构后使用统一的MCPClientFacade门面来管理所有MCP操作。
|
||||
"""
|
||||
|
||||
import time
|
||||
import json
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.mcp_plugin import MCPPlugin
|
||||
from app.models.settings import Settings as UserSettings
|
||||
from app.mcp import mcp_client, MCPPluginConfig # 使用新的统一门面
|
||||
from app.services.ai_service import create_user_ai_service
|
||||
from app.schemas.mcp_plugin import MCPTestResult
|
||||
from app.services.prompt_service import prompt_service
|
||||
from app.logger import get_logger
|
||||
from app.user_manager import User
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MCPTestService:
|
||||
"""MCP插件测试服务(使用统一门面重构)"""
|
||||
|
||||
async def _ensure_plugin_registered(
|
||||
self,
|
||||
plugin: MCPPlugin,
|
||||
user_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
确保插件已注册到统一门面
|
||||
|
||||
Args:
|
||||
plugin: 插件配置
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if plugin.plugin_type in ("http", "streamable_http", "sse") and plugin.server_url:
|
||||
return await mcp_client.ensure_registered(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin.plugin_name,
|
||||
url=plugin.server_url,
|
||||
plugin_type=plugin.plugin_type,
|
||||
headers=plugin.headers
|
||||
)
|
||||
return False
|
||||
|
||||
async def test_plugin_connection(
|
||||
self,
|
||||
plugin: MCPPlugin,
|
||||
user_id: str
|
||||
) -> MCPTestResult:
|
||||
"""
|
||||
简单连接测试
|
||||
|
||||
Args:
|
||||
plugin: 插件配置
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
测试结果
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 确保插件已注册
|
||||
registered = await self._ensure_plugin_registered(plugin, user_id)
|
||||
if not registered:
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="插件注册失败",
|
||||
error="无法创建MCP客户端",
|
||||
suggestions=["请检查插件配置", "请确认服务器URL正确"]
|
||||
)
|
||||
|
||||
# 使用统一门面测试连接
|
||||
test_result = await mcp_client.test_connection(user_id, plugin.plugin_name)
|
||||
|
||||
end_time = time.time()
|
||||
response_time = round((end_time - start_time) * 1000, 2)
|
||||
|
||||
if test_result["success"]:
|
||||
return MCPTestResult(
|
||||
success=True,
|
||||
message=f"✅ 连接测试成功",
|
||||
response_time_ms=response_time,
|
||||
tools_count=test_result.get("tools_count", 0),
|
||||
suggestions=[
|
||||
f"响应时间: {response_time}ms",
|
||||
f"可用工具数: {test_result.get('tools_count', 0)}"
|
||||
]
|
||||
)
|
||||
else:
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="❌ 连接测试失败",
|
||||
response_time_ms=response_time,
|
||||
error=test_result.get("message", "未知错误"),
|
||||
error_type=test_result.get("error_type"),
|
||||
suggestions=[
|
||||
"请检查服务器是否在线",
|
||||
"请确认配置正确",
|
||||
"请检查API Key是否有效"
|
||||
]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
response_time = round((end_time - start_time) * 1000, 2)
|
||||
|
||||
logger.error(f"测试插件失败: {plugin.plugin_name}, 错误: {e}")
|
||||
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="❌ 测试失败",
|
||||
response_time_ms=response_time,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
suggestions=[
|
||||
"请检查服务器是否在线",
|
||||
"请确认配置正确",
|
||||
"请检查API Key是否有效"
|
||||
]
|
||||
)
|
||||
|
||||
async def test_plugin_with_ai(
|
||||
self,
|
||||
plugin: MCPPlugin,
|
||||
user: User,
|
||||
db_session: AsyncSession
|
||||
) -> MCPTestResult:
|
||||
"""
|
||||
使用AI进行智能工具调用测试
|
||||
|
||||
Args:
|
||||
plugin: 插件配置
|
||||
user: 用户对象
|
||||
db_session: 数据库会话
|
||||
|
||||
Returns:
|
||||
测试结果
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 1. 先进行连接测试
|
||||
connection_result = await self.test_plugin_connection(plugin, user.user_id)
|
||||
|
||||
if not connection_result.success:
|
||||
return connection_result
|
||||
|
||||
# 2. 使用统一门面获取工具列表
|
||||
tools = await mcp_client.get_tools(user.user_id, plugin.plugin_name)
|
||||
|
||||
if not tools:
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="插件没有提供任何工具",
|
||||
error="工具列表为空",
|
||||
response_time_ms=connection_result.response_time_ms,
|
||||
suggestions=["请检查插件配置", "请确认MCP服务器正常运行"]
|
||||
)
|
||||
|
||||
# 3. 获取用户的AI设置
|
||||
settings_result = await db_session.execute(
|
||||
select(UserSettings).where(UserSettings.user_id == user.user_id)
|
||||
)
|
||||
user_settings = settings_result.scalar_one_or_none()
|
||||
|
||||
if not user_settings or not user_settings.api_key:
|
||||
# 没有AI配置,返回简单测试结果
|
||||
logger.warning("用户未配置AI服务,跳过智能测试")
|
||||
return MCPTestResult(
|
||||
success=True,
|
||||
message=f"✅ 连接测试成功(未配置AI,跳过工具调用测试)",
|
||||
response_time_ms=connection_result.response_time_ms,
|
||||
tools_count=len(tools),
|
||||
suggestions=[
|
||||
f"连接测试: 成功",
|
||||
f"可用工具数: {len(tools)}",
|
||||
"提示: 配置AI服务后可进行智能工具调用测试"
|
||||
]
|
||||
)
|
||||
|
||||
# 4. 使用AI选择工具并生成测试参数
|
||||
logger.info(f"使用AI分析工具并生成测试计划...")
|
||||
|
||||
ai_service = create_user_ai_service(
|
||||
api_provider=user_settings.api_provider,
|
||||
api_key=user_settings.api_key,
|
||||
api_base_url=user_settings.api_base_url,
|
||||
model_name=user_settings.llm_model,
|
||||
temperature=0.3,
|
||||
max_tokens=1000
|
||||
)
|
||||
|
||||
# 使用统一门面转换为OpenAI Function Calling格式
|
||||
openai_tools = mcp_client.format_tools_for_openai(tools, plugin.plugin_name)
|
||||
|
||||
logger.info(f"📋 转换后的OpenAI工具数量: {len(openai_tools)}")
|
||||
logger.debug(f"📋 OpenAI工具列表: {[t['function']['name'] for t in openai_tools]}")
|
||||
|
||||
# 调用AI选择工具(使用自定义模板系统)
|
||||
prompts = await prompt_service.get_mcp_tool_test_prompts(
|
||||
plugin_name=plugin.plugin_name,
|
||||
user_id=user.user_id,
|
||||
db=db_session
|
||||
)
|
||||
|
||||
# 使用 generate_text 进行 Function Calling(非流式)
|
||||
ai_response = await ai_service.generate_text(
|
||||
prompt=prompts["user"],
|
||||
system_prompt=prompts["system"],
|
||||
tools=openai_tools,
|
||||
tool_choice="auto"
|
||||
)
|
||||
|
||||
accumulated_text = ai_response.get("content", "")
|
||||
tool_calls = ai_response.get("tool_calls")
|
||||
|
||||
# 5. 检查AI是否返回工具调用
|
||||
if not tool_calls:
|
||||
logger.error(f"❌ AI未返回工具调用")
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="❌ AI Function Calling失败",
|
||||
error=f"AI未返回工具调用请求。响应: {accumulated_text[:200] if accumulated_text else 'N/A'}",
|
||||
tools_count=len(tools),
|
||||
suggestions=[
|
||||
"请确认使用的AI模型支持Function Calling",
|
||||
f"当前Provider: {user_settings.api_provider}",
|
||||
f"当前模型: {user_settings.llm_model}"
|
||||
]
|
||||
)
|
||||
|
||||
# 6. 解析工具调用
|
||||
tool_call = tool_calls[0]
|
||||
function = tool_call["function"]
|
||||
tool_name_with_prefix = function["name"]
|
||||
test_arguments = function["arguments"]
|
||||
|
||||
if isinstance(test_arguments, str):
|
||||
try:
|
||||
# 使用统一的JSON清洗方法
|
||||
cleaned_args = ai_service._clean_json_response(test_arguments)
|
||||
test_arguments = json.loads(cleaned_args)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"❌ 解析AI参数失败: {e}")
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="❌ AI返回的参数格式错误",
|
||||
error=f"无法解析参数JSON: {str(e)}",
|
||||
tools_count=len(tools)
|
||||
)
|
||||
|
||||
# 解析插件名和工具名
|
||||
try:
|
||||
_, tool_name = mcp_client.parse_function_name(tool_name_with_prefix)
|
||||
except ValueError:
|
||||
tool_name = tool_name_with_prefix
|
||||
|
||||
logger.info(f"🤖 AI选择的工具: {tool_name}")
|
||||
logger.info(f"📝 AI生成的参数: {test_arguments}")
|
||||
|
||||
# 7. 使用统一门面调用MCP工具
|
||||
call_start = time.time()
|
||||
try:
|
||||
tool_result = await mcp_client.call_tool(
|
||||
user_id=user.user_id,
|
||||
plugin_name=plugin.plugin_name,
|
||||
tool_name=tool_name,
|
||||
arguments=test_arguments
|
||||
)
|
||||
|
||||
call_end = time.time()
|
||||
call_time = round((call_end - call_start) * 1000, 2)
|
||||
total_time = round((call_end - start_time) * 1000, 2)
|
||||
|
||||
# 格式化结果
|
||||
result_str = str(tool_result)
|
||||
if len(result_str) > 800:
|
||||
result_preview = result_str[:800] + "\n...(结果已截断)"
|
||||
else:
|
||||
result_preview = result_str
|
||||
|
||||
return MCPTestResult(
|
||||
success=True,
|
||||
message=f"✅ Function Calling测试成功!工具 '{tool_name}' 调用正常",
|
||||
response_time_ms=total_time,
|
||||
tools_count=len(tools),
|
||||
suggestions=[
|
||||
f"🤖 AI选择: {tool_name}",
|
||||
f"📝 参数: {json.dumps(test_arguments, ensure_ascii=False)}",
|
||||
f"⏱️ 耗时: {call_time}ms",
|
||||
f"📊 结果:\n{result_preview}"
|
||||
]
|
||||
)
|
||||
|
||||
except Exception as call_error:
|
||||
call_end = time.time()
|
||||
total_time = round((call_end - start_time) * 1000, 2)
|
||||
|
||||
logger.warning(f"工具调用失败: {tool_name}, 错误: {call_error}")
|
||||
|
||||
return MCPTestResult(
|
||||
success=True, # 连接成功就算测试通过
|
||||
message=f"⚠️ 连接成功,但工具调用失败",
|
||||
response_time_ms=total_time,
|
||||
tools_count=len(tools),
|
||||
error=f"工具 '{tool_name}' 调用失败: {str(call_error)}",
|
||||
suggestions=[
|
||||
f"✅ 连接测试: 成功",
|
||||
f"❌ 工具调用测试: 失败",
|
||||
f"🤖 AI选择: {tool_name}",
|
||||
f"❌ 错误: {str(call_error)}",
|
||||
"💡 可能原因: API Key无效、参数错误或服务限制"
|
||||
]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
total_time = round((end_time - start_time) * 1000, 2)
|
||||
|
||||
logger.error(f"测试插件失败: {plugin.plugin_name}, 错误: {e}")
|
||||
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="❌ 测试失败",
|
||||
response_time_ms=total_time,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
suggestions=[
|
||||
"请检查服务器是否在线",
|
||||
"请确认配置正确",
|
||||
"请检查API Key是否有效"
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# 全局单例
|
||||
mcp_test_service = MCPTestService()
|
||||
@@ -0,0 +1,235 @@
|
||||
"""MCP工具加载器 - 统一的工具获取入口
|
||||
|
||||
在AI请求之前,自动检查用户MCP配置并加载可用工具。
|
||||
"""
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.logger import get_logger
|
||||
from app.models.mcp_plugin import MCPPlugin
|
||||
from app.mcp import mcp_client
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserToolsCache:
|
||||
"""用户工具缓存条目"""
|
||||
tools: Optional[List[Dict[str, Any]]]
|
||||
expire_time: datetime
|
||||
hit_count: int = 0
|
||||
|
||||
|
||||
class MCPToolsLoader:
|
||||
"""
|
||||
MCP工具加载器
|
||||
|
||||
负责:
|
||||
1. 检查用户是否配置并启用了MCP插件
|
||||
2. 从各个启用的插件加载工具列表
|
||||
3. 将工具转换为OpenAI Function Calling格式
|
||||
4. 缓存结果以提升性能
|
||||
"""
|
||||
|
||||
_instance: Optional['MCPToolsLoader'] = None
|
||||
|
||||
def __new__(cls):
|
||||
"""单例模式"""
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
# 用户工具缓存: user_id -> UserToolsCache
|
||||
self._cache: Dict[str, UserToolsCache] = {}
|
||||
|
||||
# 缓存TTL(5分钟)
|
||||
self._cache_ttl = timedelta(minutes=5)
|
||||
|
||||
self._initialized = True
|
||||
logger.info("✅ MCPToolsLoader 初始化完成")
|
||||
|
||||
async def has_enabled_plugins(
|
||||
self,
|
||||
user_id: str,
|
||||
db_session: AsyncSession
|
||||
) -> bool:
|
||||
"""
|
||||
检查用户是否有启用的MCP插件
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
db_session: 数据库会话
|
||||
|
||||
Returns:
|
||||
是否有启用的插件
|
||||
"""
|
||||
try:
|
||||
query = select(MCPPlugin.id).where(
|
||||
MCPPlugin.user_id == user_id,
|
||||
MCPPlugin.enabled == True,
|
||||
MCPPlugin.plugin_type.in_(["http", "streamable_http", "sse"])
|
||||
).limit(1)
|
||||
|
||||
result = await db_session.execute(query)
|
||||
return result.scalar() is not None
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"检查用户MCP插件失败: {e}")
|
||||
return False
|
||||
|
||||
async def get_user_tools(
|
||||
self,
|
||||
user_id: str,
|
||||
db_session: AsyncSession,
|
||||
use_cache: bool = True,
|
||||
force_refresh: bool = False
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""
|
||||
获取用户的MCP工具列表(OpenAI格式)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
db_session: 数据库会话
|
||||
use_cache: 是否使用缓存
|
||||
force_refresh: 是否强制刷新
|
||||
|
||||
Returns:
|
||||
- None: 用户未配置或未启用任何MCP插件
|
||||
- []: 有配置但没有可用工具
|
||||
- List[Dict]: OpenAI Function Calling格式的工具列表
|
||||
"""
|
||||
now = datetime.now()
|
||||
|
||||
# 检查缓存
|
||||
if use_cache and not force_refresh and user_id in self._cache:
|
||||
cache_entry = self._cache[user_id]
|
||||
if now < cache_entry.expire_time:
|
||||
cache_entry.hit_count += 1
|
||||
logger.debug(f"🎯 用户工具缓存命中: {user_id} (命中次数: {cache_entry.hit_count})")
|
||||
return cache_entry.tools
|
||||
else:
|
||||
del self._cache[user_id]
|
||||
logger.debug(f"⏰ 用户工具缓存过期: {user_id}")
|
||||
|
||||
# 从数据库加载
|
||||
try:
|
||||
tools = await self._load_user_tools(user_id, db_session)
|
||||
|
||||
# 更新缓存
|
||||
self._cache[user_id] = UserToolsCache(
|
||||
tools=tools,
|
||||
expire_time=now + self._cache_ttl
|
||||
)
|
||||
|
||||
if tools:
|
||||
logger.info(f"🔧 用户 {user_id} 加载了 {len(tools)} 个MCP工具")
|
||||
else:
|
||||
logger.debug(f"📭 用户 {user_id} 没有可用的MCP工具")
|
||||
|
||||
return tools
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 加载用户MCP工具失败: {e}")
|
||||
return None
|
||||
|
||||
async def _load_user_tools(
|
||||
self,
|
||||
user_id: str,
|
||||
db_session: AsyncSession
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""
|
||||
从数据库加载用户启用的MCP插件并获取工具
|
||||
"""
|
||||
# 查询启用的插件
|
||||
query = select(MCPPlugin).where(
|
||||
MCPPlugin.user_id == user_id,
|
||||
MCPPlugin.enabled == True,
|
||||
MCPPlugin.plugin_type.in_(["http", "streamable_http", "sse"])
|
||||
).order_by(MCPPlugin.sort_order)
|
||||
|
||||
result = await db_session.execute(query)
|
||||
plugins = result.scalars().all()
|
||||
|
||||
if not plugins:
|
||||
return None
|
||||
|
||||
all_tools = []
|
||||
|
||||
for plugin in plugins:
|
||||
try:
|
||||
# 确定插件类型
|
||||
plugin_type = plugin.plugin_type
|
||||
if plugin_type == "http":
|
||||
plugin_type = "streamable_http" # 默认使用streamable_http
|
||||
|
||||
# 确保插件已注册到MCP客户端
|
||||
await mcp_client.ensure_registered(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin.plugin_name,
|
||||
url=plugin.server_url,
|
||||
plugin_type=plugin_type,
|
||||
headers=plugin.headers
|
||||
)
|
||||
|
||||
# 获取工具列表
|
||||
plugin_tools = await mcp_client.get_tools(user_id, plugin.plugin_name)
|
||||
|
||||
# 转换为OpenAI格式
|
||||
formatted = mcp_client.format_tools_for_openai(plugin_tools, plugin.plugin_name)
|
||||
all_tools.extend(formatted)
|
||||
|
||||
logger.debug(f"✅ 从插件 {plugin.plugin_name} 加载了 {len(formatted)} 个工具")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 加载插件 {plugin.plugin_name} 工具失败: {e}")
|
||||
continue
|
||||
|
||||
return all_tools if all_tools else None
|
||||
|
||||
def invalidate_cache(self, user_id: Optional[str] = None):
|
||||
"""
|
||||
使缓存失效
|
||||
|
||||
Args:
|
||||
user_id: 用户ID,为None时清空所有缓存
|
||||
"""
|
||||
if user_id:
|
||||
if user_id in self._cache:
|
||||
del self._cache[user_id]
|
||||
logger.debug(f"🧹 清理用户工具缓存: {user_id}")
|
||||
else:
|
||||
count = len(self._cache)
|
||||
self._cache.clear()
|
||||
logger.info(f"🧹 清理所有用户工具缓存 ({count}个)")
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
"""获取缓存统计"""
|
||||
now = datetime.now()
|
||||
return {
|
||||
"total_entries": len(self._cache),
|
||||
"total_hits": sum(e.hit_count for e in self._cache.values()),
|
||||
"cache_ttl_minutes": self._cache_ttl.total_seconds() / 60,
|
||||
"entries": [
|
||||
{
|
||||
"user_id": uid,
|
||||
"tools_count": len(e.tools) if e.tools else 0,
|
||||
"hit_count": e.hit_count,
|
||||
"expired": now >= e.expire_time,
|
||||
"expire_time": e.expire_time.isoformat()
|
||||
}
|
||||
for uid, e in self._cache.items()
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
# 全局单例
|
||||
mcp_tools_loader = MCPToolsLoader()
|
||||
@@ -10,10 +10,60 @@ import hashlib
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 配置模型缓存目录(不设置离线模式,让它自动选择)
|
||||
# 如果本地有模型就用本地的,没有才联网下载
|
||||
# 配置模型缓存目录
|
||||
# 优先使用 backend/embedding 目录(打包后的实际位置)
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
if 'SENTENCE_TRANSFORMERS_HOME' not in os.environ:
|
||||
os.environ['SENTENCE_TRANSFORMERS_HOME'] = 'embedding'
|
||||
# 根据运行环境确定模型目录
|
||||
if getattr(sys, 'frozen', False):
|
||||
# PyInstaller 打包后 - 需要检查多个可能的位置
|
||||
exe_dir = Path(sys.executable).parent
|
||||
|
||||
# 检查顺序:
|
||||
# 1. _MEIPASS/backend/embedding (临时解压目录)
|
||||
# 2. exe同级/_internal/backend/embedding
|
||||
# 3. exe同级/backend/embedding
|
||||
possible_paths = []
|
||||
|
||||
if hasattr(sys, '_MEIPASS'):
|
||||
possible_paths.append(Path(sys._MEIPASS) / 'backend' / 'embedding')
|
||||
|
||||
possible_paths.extend([
|
||||
exe_dir / '_internal' / 'backend' / 'embedding',
|
||||
exe_dir / 'backend' / 'embedding',
|
||||
exe_dir / '_internal' / 'embedding',
|
||||
exe_dir / 'embedding'
|
||||
])
|
||||
|
||||
model_dir = None
|
||||
for path in possible_paths:
|
||||
if path.exists():
|
||||
model_dir = path
|
||||
logger.info(f"🔧 找到打包环境模型目录: {model_dir}")
|
||||
break
|
||||
|
||||
if model_dir:
|
||||
os.environ['SENTENCE_TRANSFORMERS_HOME'] = str(model_dir)
|
||||
else:
|
||||
# 最后降级方案
|
||||
fallback_dir = exe_dir / 'embedding'
|
||||
os.environ['SENTENCE_TRANSFORMERS_HOME'] = str(fallback_dir)
|
||||
logger.warning(f"⚠️ 未找到预打包模型,使用降级目录: {fallback_dir}")
|
||||
logger.warning(f" 检查过的路径: {[str(p) for p in possible_paths]}")
|
||||
else:
|
||||
# 开发模式,从当前文件位置向上找到项目根目录
|
||||
base_dir = Path(__file__).parent.parent.parent
|
||||
model_dir = base_dir / 'backend' / 'embedding'
|
||||
if model_dir.exists():
|
||||
os.environ['SENTENCE_TRANSFORMERS_HOME'] = str(model_dir)
|
||||
logger.info(f"🔧 设置开发环境模型目录: {model_dir}")
|
||||
else:
|
||||
# 降级到项目根目录的 embedding
|
||||
fallback_dir = base_dir / 'embedding'
|
||||
os.environ['SENTENCE_TRANSFORMERS_HOME'] = str(fallback_dir)
|
||||
logger.info(f"🔧 使用降级模型目录: {fallback_dir}")
|
||||
|
||||
|
||||
class MemoryService:
|
||||
@@ -44,9 +94,10 @@ class MemoryService:
|
||||
# 初始化多语言embedding模型(支持中文)
|
||||
logger.info("🔄 正在加载Embedding模型...")
|
||||
|
||||
# 确保模型缓存目录存在
|
||||
model_cache_dir = 'embedding'
|
||||
# 使用环境变量中配置的模型目录
|
||||
model_cache_dir = os.environ.get('SENTENCE_TRANSFORMERS_HOME', 'embedding')
|
||||
os.makedirs(model_cache_dir, exist_ok=True)
|
||||
logger.info(f"📂 使用模型缓存目录: {os.path.abspath(model_cache_dir)}")
|
||||
|
||||
# 调试信息:打印环境变量和路径
|
||||
logger.info(f"📂 当前工作目录: {os.getcwd()}")
|
||||
@@ -56,40 +107,91 @@ class MemoryService:
|
||||
logger.info(f"🔧 HF_HUB_OFFLINE: {os.environ.get('HF_HUB_OFFLINE', '未设置')}")
|
||||
|
||||
# 检查模型目录内容
|
||||
if os.path.exists(model_cache_dir):
|
||||
abs_cache_dir = os.path.abspath(model_cache_dir)
|
||||
logger.info(f"📂 检查模型缓存目录: {abs_cache_dir}")
|
||||
|
||||
if os.path.exists(abs_cache_dir):
|
||||
logger.info(f"📁 模型目录存在,检查内容...")
|
||||
try:
|
||||
items = os.listdir(model_cache_dir)
|
||||
logger.info(f"📁 模型目录内容: {items}")
|
||||
items = os.listdir(abs_cache_dir)
|
||||
logger.info(f"📁 模型目录内容 ({len(items)} 项): {items}")
|
||||
|
||||
# 检查是否有预期的模型文件夹
|
||||
expected_model_dir = os.path.join(model_cache_dir, 'models--sentence-transformers--paraphrase-multilingual-MiniLM-L12-v2')
|
||||
expected_model_dir = os.path.join(abs_cache_dir, 'models--sentence-transformers--paraphrase-multilingual-MiniLM-L12-v2')
|
||||
logger.info(f"🔍 检查预期路径: {expected_model_dir}")
|
||||
|
||||
if os.path.exists(expected_model_dir):
|
||||
logger.info(f"✅ 找到本地模型目录: {expected_model_dir}")
|
||||
logger.info(f"✅ 找到本地模型目录!")
|
||||
# 检查快照目录
|
||||
snapshots_dir = os.path.join(expected_model_dir, 'snapshots')
|
||||
if os.path.exists(snapshots_dir):
|
||||
snapshots = os.listdir(snapshots_dir)
|
||||
logger.info(f"📁 模型快照: {snapshots}")
|
||||
logger.info(f"📁 模型快照 ({len(snapshots)} 个): {snapshots}")
|
||||
# 检查是否有有效的快照
|
||||
if snapshots:
|
||||
logger.info(f"✅ 发现有效快照,可以使用离线模式")
|
||||
else:
|
||||
logger.warning(f"⚠️ 未找到本地模型目录: {expected_model_dir}")
|
||||
logger.warning(f"⚠️ 未找到本地模型目录")
|
||||
logger.warning(f" 预期位置: {expected_model_dir}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 检查模型目录失败: {str(e)}")
|
||||
import traceback
|
||||
logger.error(f" 堆栈: {traceback.format_exc()}")
|
||||
else:
|
||||
logger.warning(f"⚠️ 模型目录不存在: {os.path.abspath(model_cache_dir)}")
|
||||
logger.warning(f"⚠️ 模型目录不存在: {abs_cache_dir}")
|
||||
|
||||
try:
|
||||
logger.info("🔄 尝试加载主模型: sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
|
||||
# 优先使用本地缓存的模型
|
||||
# cache_folder会让模型优先从本地加载,只有不存在时才联网下载
|
||||
# 注意:不要设置local_files_only=True,这会阻止fallback到联网下载
|
||||
self.embedding_model = SentenceTransformer(
|
||||
'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
|
||||
cache_folder=model_cache_dir,
|
||||
device='cpu', # 明确指定使用CPU
|
||||
trust_remote_code=False # 安全起见
|
||||
|
||||
# 使用绝对路径检查本地模型
|
||||
abs_cache_dir = os.path.abspath(model_cache_dir)
|
||||
local_model_path = os.path.join(
|
||||
abs_cache_dir,
|
||||
'models--sentence-transformers--paraphrase-multilingual-MiniLM-L12-v2'
|
||||
)
|
||||
logger.info("✅ Embedding模型加载成功 (paraphrase-multilingual-MiniLM-L12-v2)")
|
||||
|
||||
logger.info(f"🔍 检查本地模型路径: {local_model_path}")
|
||||
logger.info(f"🔍 路径存在检查: {os.path.exists(local_model_path)}")
|
||||
|
||||
# 检查快照目录是否存在且有内容
|
||||
snapshots_dir = os.path.join(local_model_path, 'snapshots')
|
||||
has_valid_model = False
|
||||
if os.path.exists(snapshots_dir):
|
||||
try:
|
||||
snapshots = os.listdir(snapshots_dir)
|
||||
if snapshots:
|
||||
logger.info(f"✅ 发现本地模型快照: {snapshots}")
|
||||
has_valid_model = True
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 检查快照失败: {e}")
|
||||
|
||||
# 优先尝试从本地路径加载
|
||||
if has_valid_model:
|
||||
logger.info(f"✅ 检测到完整本地模型,使用离线模式加载")
|
||||
try:
|
||||
self.embedding_model = SentenceTransformer(
|
||||
'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
|
||||
cache_folder=abs_cache_dir,
|
||||
device='cpu',
|
||||
trust_remote_code=True,
|
||||
local_files_only=True # 强制使用本地文件
|
||||
)
|
||||
logger.info("✅ Embedding模型加载成功 (离线模式)")
|
||||
except Exception as local_err:
|
||||
logger.warning(f"⚠️ 离线模式加载失败: {str(local_err)}")
|
||||
logger.info("🔄 尝试在线模式...")
|
||||
raise local_err
|
||||
else:
|
||||
logger.info("📥 本地模型不完整或不存在,将联网下载...")
|
||||
logger.info(f" 下载后将保存到: {abs_cache_dir}")
|
||||
self.embedding_model = SentenceTransformer(
|
||||
'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
|
||||
cache_folder=abs_cache_dir,
|
||||
device='cpu',
|
||||
trust_remote_code=True,
|
||||
local_files_only=False # 允许联网下载
|
||||
)
|
||||
logger.info("✅ Embedding模型加载成功 (在线下载)")
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 无法加载多语言模型: {str(e)}")
|
||||
logger.error(f"❌ 详细错误: {repr(e)}")
|
||||
@@ -659,6 +761,44 @@ class MemoryService:
|
||||
logger.error(f"❌ 删除章节记忆失败: {str(e)}")
|
||||
return False
|
||||
|
||||
async def delete_project_memories(
|
||||
self,
|
||||
user_id: str,
|
||||
project_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
删除指定项目的所有记忆(包括向量数据库)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
project_id: 项目ID
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
# 生成collection名称
|
||||
user_hash = hashlib.sha256(user_id.encode()).hexdigest()[:8]
|
||||
project_hash = hashlib.sha256(project_id.encode()).hexdigest()[:8]
|
||||
collection_name = f"u_{user_hash}_p_{project_hash}"
|
||||
|
||||
# 删除整个collection(这会清理所有向量数据)
|
||||
try:
|
||||
self.client.delete_collection(name=collection_name)
|
||||
logger.info(f"🗑️ 已删除项目{project_id[:8]}的向量数据库collection: {collection_name}")
|
||||
return True
|
||||
except Exception as e:
|
||||
# 如果collection不存在,也算成功
|
||||
if "does not exist" in str(e).lower():
|
||||
logger.info(f"ℹ️ 项目{project_id[:8]}的collection不存在,无需删除")
|
||||
return True
|
||||
else:
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 删除项目记忆失败: {str(e)}")
|
||||
return False
|
||||
|
||||
async def update_memory(
|
||||
self,
|
||||
user_id: str,
|
||||
|
||||
@@ -20,11 +20,14 @@ class LinuxDOOAuthService:
|
||||
self.client_secret = settings.LINUXDO_CLIENT_SECRET
|
||||
self.redirect_uri = settings.LINUXDO_REDIRECT_URI
|
||||
|
||||
# 验证redirect_uri配置
|
||||
# 如果未配置,使用默认值(本地开发)
|
||||
if not self.redirect_uri:
|
||||
raise ValueError(
|
||||
"LINUXDO_REDIRECT_URI 未配置!\n"
|
||||
"请在 .env 文件中设置正确的回调地址:\n"
|
||||
self.redirect_uri = "http://localhost:8000/api/auth/callback"
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(
|
||||
"⚠️ LINUXDO_REDIRECT_URI 未配置,使用默认值: http://localhost:8000/api/auth/callback\n"
|
||||
"如需使用 OAuth 登录,请在 .env 文件中配置:\n"
|
||||
"本地开发: LINUXDO_REDIRECT_URI=http://localhost:8000/api/auth/callback\n"
|
||||
"Docker部署: LINUXDO_REDIRECT_URI=https://your-domain.com/api/auth/callback"
|
||||
)
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
"""剧情分析服务 - 自动分析章节的钩子、伏笔、冲突等元素"""
|
||||
from typing import Dict, Any, List, Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.prompt_service import prompt_service, PromptService
|
||||
from app.logger import get_logger
|
||||
import json
|
||||
import re
|
||||
import asyncio
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -11,169 +14,6 @@ logger = get_logger(__name__)
|
||||
class PlotAnalyzer:
|
||||
"""剧情分析器 - 使用AI分析章节内容"""
|
||||
|
||||
# AI分析提示词模板
|
||||
ANALYSIS_PROMPT = """你是一位专业的小说编辑和剧情分析师。请深度分析以下章节内容:
|
||||
|
||||
**章节信息:**
|
||||
- 章节: 第{chapter_number}章
|
||||
- 标题: {title}
|
||||
- 字数: {word_count}字
|
||||
|
||||
**章节内容:**
|
||||
{content}
|
||||
|
||||
---
|
||||
|
||||
**分析任务:**
|
||||
请从专业编辑的角度,全面分析这一章节:
|
||||
|
||||
### 1. 剧情钩子 (Hooks) - 吸引读者的元素
|
||||
识别能够吸引读者继续阅读的关键元素:
|
||||
- **悬念钩子**: 未解之谜、疑问、谜团
|
||||
- **情感钩子**: 引发共鸣的情感点、触动心弦的时刻
|
||||
- **冲突钩子**: 矛盾对抗、紧张局势
|
||||
- **认知钩子**: 颠覆认知的信息、惊人真相
|
||||
|
||||
每个钩子需要:
|
||||
- 类型分类
|
||||
- 具体内容描述
|
||||
- 强度评分(1-10)
|
||||
- 出现位置(开头/中段/结尾)
|
||||
- **关键词**: 【必填】从章节原文中逐字复制一段关键文本(8-25字),必须是原文中真实存在的连续文字,用于在文本中精确定位。不要概括或改写,必须原样复制!
|
||||
|
||||
### 2. 伏笔分析 (Foreshadowing)
|
||||
- **埋下的新伏笔**: 描述内容、预期作用、隐藏程度(1-10)
|
||||
- **回收的旧伏笔**: 呼应哪一章、回收效果评分
|
||||
- **伏笔质量**: 巧妙性和合理性评估
|
||||
- **关键词**: 【必填】从章节原文中逐字复制一段关键文本(8-25字),必须是原文中真实存在的连续文字,用于在文本中精确定位。不要概括或改写,必须原样复制!
|
||||
|
||||
### 3. 冲突分析 (Conflict)
|
||||
- 冲突类型: 人与人/人与己/人与环境/人与社会
|
||||
- 冲突各方及其立场
|
||||
- 冲突强度评分(1-10)
|
||||
- 冲突解决进度(0-100%)
|
||||
|
||||
### 4. 情感曲线 (Emotional Arc)
|
||||
- 主导情绪: 紧张/温馨/悲伤/激昂/平静等
|
||||
- 情感强度(1-10)
|
||||
- 情绪变化轨迹描述
|
||||
|
||||
### 5. 角色状态追踪 (Character Development)
|
||||
对每个出场角色分析:
|
||||
- 心理状态变化(前→后)
|
||||
- 关系变化
|
||||
- 关键行动和决策
|
||||
- 成长或退步
|
||||
|
||||
### 6. 关键情节点 (Plot Points)
|
||||
列出3-5个核心情节点:
|
||||
- 情节内容
|
||||
- 类型(revelation/conflict/resolution/transition)
|
||||
- 重要性(0.0-1.0)
|
||||
- 对故事的影响
|
||||
- **关键词**: 【必填】从章节原文中逐字复制一段关键文本(8-25字),必须是原文中真实存在的连续文字,用于在文本中精确定位。不要概括或改写,必须原样复制!
|
||||
|
||||
### 7. 场景与节奏
|
||||
- 主要场景
|
||||
- 叙事节奏(快/中/慢)
|
||||
- 对话与描写的比例
|
||||
|
||||
### 8. 质量评分
|
||||
- 节奏把控: 1-10分
|
||||
- 吸引力: 1-10分
|
||||
- 连贯性: 1-10分
|
||||
- 整体质量: 1-10分
|
||||
|
||||
### 9. 改进建议
|
||||
提供3-5条具体的改进建议
|
||||
|
||||
---
|
||||
|
||||
**输出格式(纯JSON,不要markdown标记):**
|
||||
|
||||
{{
|
||||
"hooks": [
|
||||
{{
|
||||
"type": "悬念",
|
||||
"content": "具体描述",
|
||||
"strength": 8,
|
||||
"position": "中段",
|
||||
"keyword": "必须从原文逐字复制的文本片段"
|
||||
}}
|
||||
],
|
||||
"foreshadows": [
|
||||
{{
|
||||
"content": "伏笔内容",
|
||||
"type": "planted",
|
||||
"strength": 7,
|
||||
"subtlety": 8,
|
||||
"reference_chapter": null,
|
||||
"keyword": "必须从原文逐字复制的文本片段"
|
||||
}}
|
||||
],
|
||||
"conflict": {{
|
||||
"types": ["人与人", "人与己"],
|
||||
"parties": ["主角-复仇", "反派-维护现状"],
|
||||
"level": 8,
|
||||
"description": "冲突描述",
|
||||
"resolution_progress": 0.3
|
||||
}},
|
||||
"emotional_arc": {{
|
||||
"primary_emotion": "紧张",
|
||||
"intensity": 8,
|
||||
"curve": "平静→紧张→高潮→释放",
|
||||
"secondary_emotions": ["期待", "焦虑"]
|
||||
}},
|
||||
"character_states": [
|
||||
{{
|
||||
"character_name": "张三",
|
||||
"state_before": "犹豫",
|
||||
"state_after": "坚定",
|
||||
"psychological_change": "心理变化描述",
|
||||
"key_event": "触发事件",
|
||||
"relationship_changes": {{"李四": "关系改善"}}
|
||||
}}
|
||||
],
|
||||
"plot_points": [
|
||||
{{
|
||||
"content": "情节点描述",
|
||||
"type": "revelation",
|
||||
"importance": 0.9,
|
||||
"impact": "推动故事发展",
|
||||
"keyword": "必须从原文逐字复制的文本片段"
|
||||
}}
|
||||
],
|
||||
"scenes": [
|
||||
{{
|
||||
"location": "地点",
|
||||
"atmosphere": "氛围",
|
||||
"duration": "时长估计"
|
||||
}}
|
||||
],
|
||||
"pacing": "varied",
|
||||
"dialogue_ratio": 0.4,
|
||||
"description_ratio": 0.3,
|
||||
"scores": {{
|
||||
"pacing": 8,
|
||||
"engagement": 9,
|
||||
"coherence": 8,
|
||||
"overall": 8.5
|
||||
}},
|
||||
"plot_stage": "发展",
|
||||
"suggestions": [
|
||||
"具体建议1",
|
||||
"具体建议2"
|
||||
]
|
||||
}}
|
||||
|
||||
**重要提示:**
|
||||
1. 每个钩子、伏笔、情节点的keyword字段是必填的,不能为空
|
||||
2. keyword必须是从章节原文中逐字复制的文本,长度8-25字
|
||||
3. keyword用于在前端标注文本位置,所以必须能在原文中精确找到
|
||||
4. 不要使用概括性语句或改写后的文字作为keyword
|
||||
|
||||
只返回JSON,不要其他说明。"""
|
||||
|
||||
def __init__(self, ai_service: AIService):
|
||||
"""
|
||||
初始化剧情分析器
|
||||
@@ -189,63 +29,135 @@ class PlotAnalyzer:
|
||||
chapter_number: int,
|
||||
title: str,
|
||||
content: str,
|
||||
word_count: int
|
||||
word_count: int,
|
||||
user_id: str = None,
|
||||
db: AsyncSession = None,
|
||||
max_retries: int = 3
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
分析单章内容
|
||||
分析单章内容(带重试机制)
|
||||
|
||||
Args:
|
||||
chapter_number: 章节号
|
||||
title: 章节标题
|
||||
content: 章节内容
|
||||
word_count: 字数
|
||||
user_id: 用户ID(用于获取自定义提示词)
|
||||
db: 数据库会话(用于查询自定义提示词)
|
||||
max_retries: 最大重试次数,默认3次
|
||||
|
||||
Returns:
|
||||
分析结果字典,失败返回None
|
||||
"""
|
||||
logger.info(f"🔍 开始分析第{chapter_number}章: {title}")
|
||||
|
||||
# 如果内容过长,截取前8000字(避免超token)
|
||||
analysis_content = content[:8000] if len(content) > 8000 else content
|
||||
|
||||
# 获取自定义提示词模板
|
||||
try:
|
||||
logger.info(f"🔍 开始分析第{chapter_number}章: {title}")
|
||||
|
||||
# 如果内容过长,截取前8000字(避免超token)
|
||||
analysis_content = content[:8000] if len(content) > 8000 else content
|
||||
|
||||
# 构建提示词
|
||||
prompt = self.ANALYSIS_PROMPT.format(
|
||||
chapter_number=chapter_number,
|
||||
title=title,
|
||||
word_count=word_count,
|
||||
content=analysis_content
|
||||
)
|
||||
|
||||
# 调用AI进行分析
|
||||
# 注意:不指定max_tokens,使用用户在设置中配置的值
|
||||
logger.info(f" 调用AI分析(内容长度: {len(analysis_content)}字)...")
|
||||
response = await self.ai_service.generate_text(
|
||||
prompt=prompt,
|
||||
temperature=0.3 # 降低温度以获得更稳定的JSON输出
|
||||
)
|
||||
|
||||
# 解析JSON结果
|
||||
analysis_result = self._parse_analysis_response(response)
|
||||
|
||||
if analysis_result:
|
||||
logger.info(f"✅ 第{chapter_number}章分析完成")
|
||||
logger.info(f" - 钩子: {len(analysis_result.get('hooks', []))}个")
|
||||
logger.info(f" - 伏笔: {len(analysis_result.get('foreshadows', []))}个")
|
||||
logger.info(f" - 情节点: {len(analysis_result.get('plot_points', []))}个")
|
||||
logger.info(f" - 整体评分: {analysis_result.get('scores', {}).get('overall', 'N/A')}")
|
||||
return analysis_result
|
||||
if user_id and db:
|
||||
template = await PromptService.get_template("PLOT_ANALYSIS", user_id, db)
|
||||
else:
|
||||
logger.error(f"❌ 第{chapter_number}章分析失败: JSON解析错误")
|
||||
return None
|
||||
|
||||
# 降级到系统默认模板
|
||||
template = PromptService.PLOT_ANALYSIS
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 章节分析异常: {str(e)}")
|
||||
return None
|
||||
logger.warning(f"⚠️ 获取提示词模板失败,使用默认模板: {str(e)}")
|
||||
template = PromptService.PLOT_ANALYSIS
|
||||
|
||||
# 格式化提示词
|
||||
prompt = PromptService.format_prompt(
|
||||
template,
|
||||
chapter_number=chapter_number,
|
||||
title=title,
|
||||
word_count=word_count,
|
||||
content=analysis_content
|
||||
)
|
||||
|
||||
last_error = None
|
||||
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
# 调用AI进行分析
|
||||
logger.info(f" 📡 调用AI分析(内容长度: {len(analysis_content)}字, 尝试 {attempt}/{max_retries})...")
|
||||
accumulated_text = ""
|
||||
|
||||
try:
|
||||
async for chunk in self.ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
temperature=0.3 # 降低温度以获得更稳定的JSON输出
|
||||
):
|
||||
accumulated_text += chunk
|
||||
except GeneratorExit:
|
||||
# 流式响应被中断
|
||||
logger.warning(f"⚠️ 流式响应被中断(GeneratorExit),已累积 {len(accumulated_text)} 字符")
|
||||
# 如果已经累积了足够内容,继续尝试解析
|
||||
if len(accumulated_text) < 100:
|
||||
raise Exception("流式响应中断,内容不足")
|
||||
except Exception as stream_error:
|
||||
logger.error(f"❌ 流式生成出错: {str(stream_error)}")
|
||||
raise
|
||||
|
||||
# 检查响应是否为空
|
||||
if not accumulated_text or len(accumulated_text.strip()) < 10:
|
||||
logger.warning(f"⚠️ AI响应为空或过短(长度: {len(accumulated_text)}), 尝试 {attempt}/{max_retries}")
|
||||
last_error = "AI响应为空或过短"
|
||||
if attempt < max_retries:
|
||||
wait_time = min(2 ** attempt, 10)
|
||||
logger.info(f" ⏳ 等待 {wait_time} 秒后重试...")
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
else:
|
||||
logger.error(f"❌ 第{chapter_number}章分析失败: AI响应为空,已达最大重试次数")
|
||||
return None
|
||||
|
||||
# 提取内容
|
||||
response_text = accumulated_text
|
||||
logger.debug(f" 收到AI响应,长度: {len(response_text)} 字符")
|
||||
|
||||
# 解析JSON结果
|
||||
analysis_result = self._parse_analysis_response(response_text)
|
||||
|
||||
if analysis_result:
|
||||
logger.info(f"✅ 第{chapter_number}章分析完成 (尝试 {attempt}/{max_retries})")
|
||||
logger.info(f" - 钩子: {len(analysis_result.get('hooks', []))}个")
|
||||
logger.info(f" - 伏笔: {len(analysis_result.get('foreshadows', []))}个")
|
||||
logger.info(f" - 情节点: {len(analysis_result.get('plot_points', []))}个")
|
||||
logger.info(f" - 整体评分: {analysis_result.get('scores', {}).get('overall', 'N/A')}")
|
||||
return analysis_result
|
||||
else:
|
||||
# JSON解析失败,重试
|
||||
logger.warning(f"⚠️ JSON解析失败, 尝试 {attempt}/{max_retries}")
|
||||
last_error = "JSON解析失败"
|
||||
if attempt < max_retries:
|
||||
wait_time = min(2 ** attempt, 10)
|
||||
logger.info(f" ⏳ 等待 {wait_time} 秒后重试...")
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
else:
|
||||
logger.error(f"❌ 第{chapter_number}章分析失败: JSON解析错误,已达最大重试次数")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
last_error = str(e)
|
||||
logger.error(f"❌ 章节分析异常(尝试 {attempt}/{max_retries}): {last_error}")
|
||||
|
||||
if attempt < max_retries:
|
||||
wait_time = min(2 ** attempt, 10)
|
||||
logger.info(f" ⏳ 等待 {wait_time} 秒后重试...")
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
else:
|
||||
logger.error(f"❌ 第{chapter_number}章分析失败: {last_error},已达最大重试次数")
|
||||
return None
|
||||
|
||||
# 不应该到达这里,但作为安全措施
|
||||
logger.error(f"❌ 第{chapter_number}章分析失败: {last_error}")
|
||||
return None
|
||||
|
||||
def _parse_analysis_response(self, response: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
解析AI返回的分析结果
|
||||
解析AI返回的分析结果(使用统一的JSON清洗方法)
|
||||
|
||||
Args:
|
||||
response: AI返回的文本
|
||||
@@ -254,13 +166,8 @@ class PlotAnalyzer:
|
||||
解析后的字典,失败返回None
|
||||
"""
|
||||
try:
|
||||
# 清理响应文本
|
||||
cleaned = response.strip()
|
||||
|
||||
# 移除可能的markdown标记
|
||||
cleaned = re.sub(r'^```json\s*', '', cleaned)
|
||||
cleaned = re.sub(r'^```\s*', '', cleaned)
|
||||
cleaned = re.sub(r'\s*```$', '', cleaned)
|
||||
# 使用统一的JSON清洗方法
|
||||
cleaned = self.ai_service._clean_json_response(response)
|
||||
|
||||
# 尝试解析JSON
|
||||
result = json.loads(cleaned)
|
||||
@@ -272,22 +179,12 @@ class PlotAnalyzer:
|
||||
logger.warning(f"⚠️ 分析结果缺少字段: {field}")
|
||||
result[field] = [] if field != 'scores' else {}
|
||||
|
||||
logger.info("✅ 成功解析分析结果")
|
||||
return result
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"❌ JSON解析失败: {str(e)}")
|
||||
logger.error(f" 原始响应(前500字): {response[:500]}")
|
||||
|
||||
# 尝试提取JSON部分
|
||||
json_match = re.search(r'\{[\s\S]*\}', response)
|
||||
if json_match:
|
||||
try:
|
||||
result = json.loads(json_match.group())
|
||||
logger.info("✅ 通过正则提取成功解析JSON")
|
||||
return result
|
||||
except:
|
||||
pass
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 解析异常: {str(e)}")
|
||||
@@ -298,7 +195,8 @@ class PlotAnalyzer:
|
||||
analysis: Dict[str, Any],
|
||||
chapter_id: str,
|
||||
chapter_number: int,
|
||||
chapter_content: str = ""
|
||||
chapter_content: str = "",
|
||||
chapter_title: str = ""
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
从分析结果中提取记忆片段
|
||||
@@ -308,6 +206,7 @@ class PlotAnalyzer:
|
||||
chapter_id: 章节ID
|
||||
chapter_number: 章节号
|
||||
chapter_content: 章节完整内容(用于计算位置)
|
||||
chapter_title: 章节标题
|
||||
|
||||
Returns:
|
||||
记忆片段列表
|
||||
@@ -315,6 +214,38 @@ class PlotAnalyzer:
|
||||
memories = []
|
||||
|
||||
try:
|
||||
# 【新增】0. 提取章节摘要作为记忆(用于语义检索相关章节)
|
||||
chapter_summary = ""
|
||||
|
||||
# 尝试从分析结果获取摘要
|
||||
if analysis.get('summary'):
|
||||
chapter_summary = analysis.get('summary')
|
||||
# 或者从情节点组合生成摘要
|
||||
elif analysis.get('plot_points'):
|
||||
plot_summaries = [p.get('content', '') for p in analysis.get('plot_points', [])[:3]]
|
||||
chapter_summary = ";".join(plot_summaries)
|
||||
# 或者使用内容前300字
|
||||
elif chapter_content:
|
||||
chapter_summary = chapter_content[:300] + ("..." if len(chapter_content) > 300 else "")
|
||||
|
||||
# 如果有摘要,添加到记忆中
|
||||
if chapter_summary:
|
||||
memories.append({
|
||||
'type': 'chapter_summary',
|
||||
'content': chapter_summary,
|
||||
'title': f"第{chapter_number}章《{chapter_title}》摘要",
|
||||
'metadata': {
|
||||
'chapter_id': chapter_id,
|
||||
'chapter_number': chapter_number,
|
||||
'importance_score': 0.6, # 中等重要性
|
||||
'tags': ['摘要', '章节概览', chapter_title],
|
||||
'is_foreshadow': 0,
|
||||
'text_position': 0,
|
||||
'text_length': len(chapter_summary)
|
||||
}
|
||||
})
|
||||
logger.info(f" ✅ 添加章节摘要记忆: {len(chapter_summary)}字")
|
||||
|
||||
# 1. 提取钩子作为记忆
|
||||
for i, hook in enumerate(analysis.get('hooks', [])):
|
||||
if hook.get('strength', 0) >= 6: # 只保存强度>=6的钩子
|
||||
|
||||
@@ -0,0 +1,645 @@
|
||||
"""大纲剧情展开服务 - 将大纲节点展开为多个章节"""
|
||||
from typing import List, Dict, Any, Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
import json
|
||||
|
||||
from app.models.outline import Outline
|
||||
from app.models.project import Project
|
||||
from app.models.character import Character
|
||||
from app.models.chapter import Chapter
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.prompt_service import prompt_service, PromptService
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PlotExpansionService:
|
||||
"""大纲剧情展开服务"""
|
||||
|
||||
def __init__(self, ai_service: AIService):
|
||||
self.ai_service = ai_service
|
||||
|
||||
async def analyze_outline_for_chapters(
|
||||
self,
|
||||
outline: Outline,
|
||||
project: Project,
|
||||
db: AsyncSession,
|
||||
target_chapter_count: int = 3,
|
||||
expansion_strategy: str = "balanced",
|
||||
enable_scene_analysis: bool = True,
|
||||
provider: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
batch_size: int = 5,
|
||||
progress_callback: Optional[callable] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
分析单个大纲,生成多章节规划(支持分批生成)
|
||||
|
||||
Args:
|
||||
outline: 大纲对象
|
||||
project: 项目对象
|
||||
db: 数据库会话
|
||||
target_chapter_count: 目标生成章节数
|
||||
expansion_strategy: 展开策略(balanced/climax/detail)
|
||||
enable_scene_analysis: 是否启用场景级分析
|
||||
provider: AI提供商
|
||||
model: AI模型
|
||||
batch_size: 每批生成的章节数(默认5章)
|
||||
progress_callback: 进度回调函数(可选)
|
||||
|
||||
Returns:
|
||||
章节规划列表
|
||||
"""
|
||||
logger.info(f"开始分析大纲 {outline.id},目标生成 {target_chapter_count} 章")
|
||||
|
||||
# 如果章节数较少,直接生成
|
||||
if target_chapter_count <= batch_size:
|
||||
return await self._generate_chapters_single_batch(
|
||||
outline=outline,
|
||||
project=project,
|
||||
db=db,
|
||||
target_chapter_count=target_chapter_count,
|
||||
expansion_strategy=expansion_strategy,
|
||||
enable_scene_analysis=enable_scene_analysis,
|
||||
provider=provider,
|
||||
model=model
|
||||
)
|
||||
|
||||
# 章节数较多,分批生成
|
||||
logger.info(f"章节数({target_chapter_count})超过批次大小({batch_size}),启用分批生成")
|
||||
return await self._generate_chapters_in_batches(
|
||||
outline=outline,
|
||||
project=project,
|
||||
db=db,
|
||||
target_chapter_count=target_chapter_count,
|
||||
expansion_strategy=expansion_strategy,
|
||||
enable_scene_analysis=enable_scene_analysis,
|
||||
provider=provider,
|
||||
model=model,
|
||||
batch_size=batch_size,
|
||||
progress_callback=progress_callback
|
||||
)
|
||||
|
||||
async def _generate_chapters_single_batch(
|
||||
self,
|
||||
outline: Outline,
|
||||
project: Project,
|
||||
db: AsyncSession,
|
||||
target_chapter_count: int,
|
||||
expansion_strategy: str,
|
||||
enable_scene_analysis: bool,
|
||||
provider: Optional[str],
|
||||
model: Optional[str]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""单批次生成章节规划"""
|
||||
# 获取角色信息
|
||||
characters_result = await db.execute(
|
||||
select(Character).where(Character.project_id == project.id)
|
||||
)
|
||||
characters = characters_result.scalars().all()
|
||||
characters_info = "\n".join([
|
||||
f"- {char.name} ({'组织' if char.is_organization else '角色'}, {char.role_type}): "
|
||||
f"{char.personality[:100] if char.personality else '暂无描述'}"
|
||||
for char in characters
|
||||
])
|
||||
|
||||
# 获取大纲上下文(前后大纲)
|
||||
context_info = await self._get_outline_context(outline, project.id, db)
|
||||
|
||||
# 获取自定义提示词模板
|
||||
template = await PromptService.get_template("OUTLINE_EXPAND_SINGLE", project.user_id, db)
|
||||
# 格式化提示词
|
||||
prompt = PromptService.format_prompt(
|
||||
template,
|
||||
project_title=project.title,
|
||||
project_genre=project.genre or '通用',
|
||||
project_theme=project.theme or '未设定',
|
||||
project_narrative_perspective=project.narrative_perspective or '第三人称',
|
||||
project_world_time_period=project.world_time_period or '未设定',
|
||||
project_world_location=project.world_location or '未设定',
|
||||
project_world_atmosphere=project.world_atmosphere or '未设定',
|
||||
characters_info=characters_info or '暂无角色',
|
||||
outline_order_index=outline.order_index,
|
||||
outline_title=outline.title,
|
||||
outline_content=outline.content,
|
||||
context_info=context_info,
|
||||
strategy_instruction=expansion_strategy,
|
||||
target_chapter_count=target_chapter_count,
|
||||
scene_instruction="", # 暂时为空
|
||||
scene_field="" # 暂时为空
|
||||
)
|
||||
|
||||
# 调用AI生成章节规划
|
||||
logger.info(f"调用AI生成章节规划...")
|
||||
accumulated_text = ""
|
||||
async for chunk in self.ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
):
|
||||
accumulated_text += chunk
|
||||
|
||||
# 提取内容
|
||||
ai_content = accumulated_text
|
||||
|
||||
# 解析AI响应
|
||||
chapter_plans = self._parse_expansion_response(ai_content, outline.id)
|
||||
|
||||
logger.info(f"成功生成 {len(chapter_plans)} 个章节规划")
|
||||
return chapter_plans
|
||||
|
||||
async def _generate_chapters_in_batches(
|
||||
self,
|
||||
outline: Outline,
|
||||
project: Project,
|
||||
db: AsyncSession,
|
||||
target_chapter_count: int,
|
||||
expansion_strategy: str,
|
||||
enable_scene_analysis: bool,
|
||||
provider: Optional[str],
|
||||
model: Optional[str],
|
||||
batch_size: int,
|
||||
progress_callback: Optional[callable]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""分批生成章节规划"""
|
||||
# 计算批次数
|
||||
total_batches = (target_chapter_count + batch_size - 1) // batch_size
|
||||
logger.info(f"分批生成计划: 总共{target_chapter_count}章,分{total_batches}批,每批{batch_size}章")
|
||||
|
||||
# 获取角色信息(所有批次共用)
|
||||
characters_result = await db.execute(
|
||||
select(Character).where(Character.project_id == project.id)
|
||||
)
|
||||
characters = characters_result.scalars().all()
|
||||
characters_info = "\n".join([
|
||||
f"- {char.name} ({'组织' if char.is_organization else '角色'}, {char.role_type}): "
|
||||
f"{char.personality[:100] if char.personality else '暂无描述'}"
|
||||
for char in characters
|
||||
])
|
||||
|
||||
# 获取大纲上下文
|
||||
context_info = await self._get_outline_context(outline, project.id, db)
|
||||
|
||||
all_chapter_plans = []
|
||||
|
||||
for batch_num in range(total_batches):
|
||||
# 计算当前批次的章节数
|
||||
remaining_chapters = target_chapter_count - len(all_chapter_plans)
|
||||
current_batch_size = min(batch_size, remaining_chapters)
|
||||
current_start_index = len(all_chapter_plans) + 1
|
||||
|
||||
logger.info(f"开始生成第{batch_num + 1}/{total_batches}批,章节范围: {current_start_index}-{current_start_index + current_batch_size - 1}")
|
||||
|
||||
# 回调通知进度
|
||||
if progress_callback:
|
||||
await progress_callback(batch_num + 1, total_batches, current_start_index, current_batch_size)
|
||||
|
||||
# 构建当前批次的提示词(包含已生成章节的上下文)
|
||||
previous_context = ""
|
||||
if all_chapter_plans:
|
||||
previous_summaries = []
|
||||
for ch in all_chapter_plans[-3:]: # 只显示最近3章
|
||||
previous_summaries.append(
|
||||
f"第{ch['sub_index']}节《{ch['title']}》: {ch['plot_summary'][:100]}..."
|
||||
)
|
||||
previous_context = f"""
|
||||
【已生成章节概要】(接续生成,注意衔接)
|
||||
{chr(10).join(previous_summaries)}
|
||||
|
||||
⚠️ 当前是第{current_start_index}-{current_start_index + current_batch_size - 1}节(共{target_chapter_count}节中的一部分)
|
||||
"""
|
||||
# 获取自定义提示词模板
|
||||
template = await PromptService.get_template("OUTLINE_EXPAND_MULTI", project.user_id, db)
|
||||
# 格式化提示词
|
||||
prompt = PromptService.format_prompt(
|
||||
template,
|
||||
project_title=project.title,
|
||||
project_genre=project.genre or '通用',
|
||||
project_theme=project.theme or '未设定',
|
||||
project_narrative_perspective=project.narrative_perspective or '第三人称',
|
||||
project_world_time_period=project.world_time_period or '未设定',
|
||||
project_world_location=project.world_location or '未设定',
|
||||
project_world_atmosphere=project.world_atmosphere or '未设定',
|
||||
characters_info=characters_info or '暂无角色',
|
||||
outline_order_index=outline.order_index,
|
||||
outline_title=outline.title,
|
||||
outline_content=outline.content,
|
||||
context_info=context_info,
|
||||
previous_context=previous_context,
|
||||
strategy_instruction=expansion_strategy,
|
||||
start_index=current_start_index,
|
||||
end_index=current_start_index + current_batch_size - 1,
|
||||
target_chapter_count=current_batch_size,
|
||||
scene_instruction="", # 暂时为空
|
||||
scene_field="" # 暂时为空
|
||||
)
|
||||
|
||||
# 调用AI生成当前批次
|
||||
logger.info(f"调用AI生成第{batch_num + 1}批...")
|
||||
accumulated_text = ""
|
||||
async for chunk in self.ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
):
|
||||
accumulated_text += chunk
|
||||
|
||||
# 提取内容
|
||||
ai_content = accumulated_text
|
||||
|
||||
# 解析AI响应
|
||||
batch_plans = self._parse_expansion_response(ai_content, outline.id)
|
||||
|
||||
# 调整sub_index以保持连续性
|
||||
for i, plan in enumerate(batch_plans):
|
||||
plan["sub_index"] = current_start_index + i
|
||||
|
||||
all_chapter_plans.extend(batch_plans)
|
||||
|
||||
logger.info(f"第{batch_num + 1}批生成完成,本批生成{len(batch_plans)}章,累计{len(all_chapter_plans)}章")
|
||||
|
||||
logger.info(f"分批生成完成,共生成 {len(all_chapter_plans)} 个章节规划")
|
||||
return all_chapter_plans
|
||||
|
||||
async def batch_expand_outlines(
|
||||
self,
|
||||
project_id: str,
|
||||
db: AsyncSession,
|
||||
ai_service: AIService,
|
||||
target_chapters_per_outline: int = 3,
|
||||
expansion_strategy: str = "balanced",
|
||||
provider: Optional[str] = None,
|
||||
model: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
批量展开所有大纲为章节
|
||||
|
||||
Returns:
|
||||
{
|
||||
"total_outlines": 总大纲数,
|
||||
"total_chapters_planned": 规划的总章节数,
|
||||
"expansions": [每个大纲的展开结果]
|
||||
}
|
||||
"""
|
||||
logger.info(f"开始批量展开项目 {project_id} 的所有大纲")
|
||||
|
||||
# 获取项目
|
||||
project_result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = project_result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise ValueError(f"项目 {project_id} 不存在")
|
||||
|
||||
# 获取所有大纲
|
||||
outlines_result = await db.execute(
|
||||
select(Outline)
|
||||
.where(Outline.project_id == project_id)
|
||||
.order_by(Outline.order_index)
|
||||
)
|
||||
outlines = outlines_result.scalars().all()
|
||||
|
||||
if not outlines:
|
||||
logger.warning(f"项目 {project_id} 没有大纲")
|
||||
return {
|
||||
"total_outlines": 0,
|
||||
"total_chapters_planned": 0,
|
||||
"expansions": []
|
||||
}
|
||||
|
||||
# 逐个展开大纲
|
||||
expansions = []
|
||||
total_chapters = 0
|
||||
|
||||
for outline in outlines:
|
||||
try:
|
||||
chapter_plans = await self.analyze_outline_for_chapters(
|
||||
outline=outline,
|
||||
project=project,
|
||||
db=db,
|
||||
target_chapter_count=target_chapters_per_outline,
|
||||
expansion_strategy=expansion_strategy,
|
||||
provider=provider,
|
||||
model=model
|
||||
)
|
||||
|
||||
expansions.append({
|
||||
"outline_id": outline.id,
|
||||
"outline_title": outline.title,
|
||||
"chapter_plans": chapter_plans,
|
||||
"chapter_count": len(chapter_plans)
|
||||
})
|
||||
|
||||
total_chapters += len(chapter_plans)
|
||||
logger.info(f"大纲 {outline.title} 展开为 {len(chapter_plans)} 章")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"展开大纲 {outline.id} 失败: {str(e)}")
|
||||
expansions.append({
|
||||
"outline_id": outline.id,
|
||||
"outline_title": outline.title,
|
||||
"error": str(e),
|
||||
"chapter_count": 0
|
||||
})
|
||||
|
||||
result = {
|
||||
"total_outlines": len(outlines),
|
||||
"total_chapters_planned": total_chapters,
|
||||
"expansions": expansions
|
||||
}
|
||||
|
||||
logger.info(f"批量展开完成: {len(outlines)} 个大纲 → {total_chapters} 个章节规划")
|
||||
return result
|
||||
|
||||
async def create_chapters_from_plans(
|
||||
self,
|
||||
outline_id: str,
|
||||
chapter_plans: List[Dict[str, Any]],
|
||||
project_id: str,
|
||||
db: AsyncSession,
|
||||
start_chapter_number: int = None
|
||||
) -> List[Chapter]:
|
||||
"""
|
||||
根据章节规划创建实际的章节记录
|
||||
|
||||
Args:
|
||||
outline_id: 大纲ID
|
||||
chapter_plans: 章节规划列表
|
||||
project_id: 项目ID
|
||||
db: 数据库会话
|
||||
start_chapter_number: 起始章节号(如果为None,则自动计算)
|
||||
|
||||
Returns:
|
||||
创建的章节列表
|
||||
"""
|
||||
logger.info(f"根据规划创建 {len(chapter_plans)} 个章节记录")
|
||||
|
||||
# 如果没有指定起始章节号,根据大纲顺序自动计算
|
||||
if start_chapter_number is None:
|
||||
# 1. 获取当前大纲信息
|
||||
outline_result = await db.execute(
|
||||
select(Outline).where(Outline.id == outline_id)
|
||||
)
|
||||
current_outline = outline_result.scalar_one_or_none()
|
||||
|
||||
if not current_outline:
|
||||
raise ValueError(f"大纲 {outline_id} 不存在")
|
||||
|
||||
# 2. 查询所有在当前大纲之前的大纲(按order_index排序)
|
||||
prev_outlines_result = await db.execute(
|
||||
select(Outline)
|
||||
.where(
|
||||
Outline.project_id == project_id,
|
||||
Outline.order_index < current_outline.order_index
|
||||
)
|
||||
.order_by(Outline.order_index)
|
||||
)
|
||||
prev_outlines = prev_outlines_result.scalars().all()
|
||||
|
||||
# 3. 计算前面所有大纲已展开的章节总数
|
||||
total_prev_chapters = 0
|
||||
for prev_outline in prev_outlines:
|
||||
count_result = await db.execute(
|
||||
select(func.count(Chapter.id))
|
||||
.where(
|
||||
Chapter.project_id == project_id,
|
||||
Chapter.outline_id == prev_outline.id
|
||||
)
|
||||
)
|
||||
total_prev_chapters += count_result.scalar() or 0
|
||||
|
||||
# 4. 起始章节号 = 前面所有大纲的章节数 + 1
|
||||
start_chapter_number = total_prev_chapters + 1
|
||||
logger.info(f"自动计算起始章节号: {start_chapter_number} (基于大纲order_index={current_outline.order_index}, 前置章节数={total_prev_chapters})")
|
||||
|
||||
chapters = []
|
||||
for idx, plan in enumerate(chapter_plans):
|
||||
# 保存完整的展开规划数据(JSON格式)
|
||||
expansion_plan_json = json.dumps({
|
||||
"key_events": plan.get("key_events", []),
|
||||
"character_focus": plan.get("character_focus", []),
|
||||
"emotional_tone": plan.get("emotional_tone", ""),
|
||||
"narrative_goal": plan.get("narrative_goal", ""),
|
||||
"conflict_type": plan.get("conflict_type", ""),
|
||||
"estimated_words": plan.get("estimated_words", 3000),
|
||||
"scenes": plan.get("scenes", []) if plan.get("scenes") else None
|
||||
}, ensure_ascii=False)
|
||||
|
||||
chapter = Chapter(
|
||||
project_id=project_id,
|
||||
outline_id=outline_id,
|
||||
chapter_number=start_chapter_number + idx,
|
||||
sub_index=plan.get("sub_index", idx + 1),
|
||||
title=plan.get("title", f"第{start_chapter_number + idx}章"),
|
||||
summary=plan.get("plot_summary", ""),
|
||||
expansion_plan=expansion_plan_json,
|
||||
status="draft"
|
||||
)
|
||||
db.add(chapter)
|
||||
chapters.append(chapter)
|
||||
|
||||
await db.commit()
|
||||
|
||||
for chapter in chapters:
|
||||
await db.refresh(chapter)
|
||||
|
||||
logger.info(f"成功创建 {len(chapters)} 个章节记录(已保存展开规划数据)")
|
||||
|
||||
# 重新排序当前大纲之后的所有章节
|
||||
await self._renumber_subsequent_chapters(
|
||||
project_id=project_id,
|
||||
current_outline_id=outline_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
return chapters
|
||||
|
||||
async def _get_outline_context(
|
||||
self,
|
||||
outline: Outline,
|
||||
project_id: str,
|
||||
db: AsyncSession
|
||||
) -> str:
|
||||
"""获取大纲的上下文(前后大纲)"""
|
||||
# 获取前一个大纲
|
||||
prev_result = await db.execute(
|
||||
select(Outline)
|
||||
.where(
|
||||
Outline.project_id == project_id,
|
||||
Outline.order_index < outline.order_index
|
||||
)
|
||||
.order_by(Outline.order_index.desc())
|
||||
.limit(1)
|
||||
)
|
||||
prev_outline = prev_result.scalar_one_or_none()
|
||||
|
||||
# 获取后一个大纲
|
||||
next_result = await db.execute(
|
||||
select(Outline)
|
||||
.where(
|
||||
Outline.project_id == project_id,
|
||||
Outline.order_index > outline.order_index
|
||||
)
|
||||
.order_by(Outline.order_index)
|
||||
.limit(1)
|
||||
)
|
||||
next_outline = next_result.scalar_one_or_none()
|
||||
|
||||
context = ""
|
||||
if prev_outline:
|
||||
context += f"【前一节】{prev_outline.title}: {prev_outline.content[:200]}...\n\n"
|
||||
if next_outline:
|
||||
context += f"【后一节】{next_outline.title}: {next_outline.content[:200]}...\n"
|
||||
|
||||
return context if context else "(无前后文)"
|
||||
|
||||
|
||||
def _parse_expansion_response(
|
||||
self,
|
||||
ai_response: str,
|
||||
outline_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""解析AI的展开响应(使用统一的JSON清洗方法)"""
|
||||
try:
|
||||
# 使用统一的JSON清洗方法
|
||||
cleaned_text = self.ai_service._clean_json_response(ai_response)
|
||||
|
||||
# 解析JSON
|
||||
chapter_plans = json.loads(cleaned_text)
|
||||
|
||||
# 确保是列表
|
||||
if not isinstance(chapter_plans, list):
|
||||
chapter_plans = [chapter_plans]
|
||||
|
||||
# 为每个章节规划添加outline_id
|
||||
for plan in chapter_plans:
|
||||
plan["outline_id"] = outline_id
|
||||
|
||||
logger.info(f"✅ 成功解析 {len(chapter_plans)} 个章节规划")
|
||||
return chapter_plans
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"❌ 解析AI响应失败: {e}, 响应内容: {ai_response[:500]}")
|
||||
# 返回一个基础规划
|
||||
return [{
|
||||
"outline_id": outline_id,
|
||||
"sub_index": 1,
|
||||
"title": "AI解析失败的默认章节",
|
||||
"plot_summary": ai_response[:500],
|
||||
"key_events": ["解析失败"],
|
||||
"character_focus": [],
|
||||
"emotional_tone": "未知",
|
||||
"narrative_goal": "需要重新生成",
|
||||
"conflict_type": "未知",
|
||||
"estimated_words": 3000
|
||||
}]
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 解析异常: {str(e)}")
|
||||
return [{
|
||||
"outline_id": outline_id,
|
||||
"sub_index": 1,
|
||||
"title": "解析异常的默认章节",
|
||||
"plot_summary": "系统错误",
|
||||
"key_events": [],
|
||||
"character_focus": [],
|
||||
"emotional_tone": "未知",
|
||||
"narrative_goal": "需要重新生成",
|
||||
"conflict_type": "未知",
|
||||
"estimated_words": 3000
|
||||
}]
|
||||
|
||||
|
||||
async def _renumber_subsequent_chapters(
|
||||
self,
|
||||
project_id: str,
|
||||
current_outline_id: str,
|
||||
db: AsyncSession
|
||||
):
|
||||
"""
|
||||
重新计算当前大纲之后所有大纲的章节序号
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
current_outline_id: 当前大纲ID
|
||||
db: 数据库会话
|
||||
"""
|
||||
logger.info(f"开始重新排序大纲 {current_outline_id} 之后的所有章节")
|
||||
|
||||
# 1. 获取当前大纲信息
|
||||
current_outline_result = await db.execute(
|
||||
select(Outline).where(Outline.id == current_outline_id)
|
||||
)
|
||||
current_outline = current_outline_result.scalar_one_or_none()
|
||||
|
||||
if not current_outline:
|
||||
logger.warning(f"大纲 {current_outline_id} 不存在,跳过重新排序")
|
||||
return
|
||||
|
||||
# 2. 获取当前大纲及之后的所有大纲(按order_index排序)
|
||||
subsequent_outlines_result = await db.execute(
|
||||
select(Outline)
|
||||
.where(
|
||||
Outline.project_id == project_id,
|
||||
Outline.order_index >= current_outline.order_index
|
||||
)
|
||||
.order_by(Outline.order_index)
|
||||
)
|
||||
subsequent_outlines = subsequent_outlines_result.scalars().all()
|
||||
|
||||
# 3. 计算每个大纲的起始章节号
|
||||
current_chapter_number = 1
|
||||
|
||||
# 先计算前面大纲的章节总数
|
||||
prev_outlines_result = await db.execute(
|
||||
select(Outline)
|
||||
.where(
|
||||
Outline.project_id == project_id,
|
||||
Outline.order_index < current_outline.order_index
|
||||
)
|
||||
.order_by(Outline.order_index)
|
||||
)
|
||||
prev_outlines = prev_outlines_result.scalars().all()
|
||||
|
||||
for prev_outline in prev_outlines:
|
||||
count_result = await db.execute(
|
||||
select(func.count(Chapter.id))
|
||||
.where(
|
||||
Chapter.project_id == project_id,
|
||||
Chapter.outline_id == prev_outline.id
|
||||
)
|
||||
)
|
||||
current_chapter_number += count_result.scalar() or 0
|
||||
|
||||
# 4. 逐个大纲更新章节序号
|
||||
updated_count = 0
|
||||
for outline in subsequent_outlines:
|
||||
# 获取该大纲的所有章节(按sub_index排序)
|
||||
chapters_result = await db.execute(
|
||||
select(Chapter)
|
||||
.where(
|
||||
Chapter.project_id == project_id,
|
||||
Chapter.outline_id == outline.id
|
||||
)
|
||||
.order_by(Chapter.sub_index)
|
||||
)
|
||||
chapters = chapters_result.scalars().all()
|
||||
|
||||
# 更新每个章节的chapter_number
|
||||
for chapter in chapters:
|
||||
if chapter.chapter_number != current_chapter_number:
|
||||
logger.debug(f"更新章节 {chapter.id}: {chapter.chapter_number} -> {current_chapter_number}")
|
||||
chapter.chapter_number = current_chapter_number
|
||||
updated_count += 1
|
||||
current_chapter_number += 1
|
||||
|
||||
# 5. 提交更新
|
||||
await db.commit()
|
||||
logger.info(f"重新排序完成,共更新 {updated_count} 个章节的序号")
|
||||
|
||||
|
||||
# 工厂函数
|
||||
def create_plot_expansion_service(ai_service: AIService) -> PlotExpansionService:
|
||||
"""创建剧情展开服务实例"""
|
||||
return PlotExpansionService(ai_service)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user