This commit is contained in:
xiamuceer
2025-10-30 11:14:43 +08:00
parent b97410d973
commit 0f6c2d344a
91 changed files with 22309 additions and 0 deletions
+48
View File
@@ -0,0 +1,48 @@
# AI服务配置
# OpenAI配置
OPENAI_API_KEY=your_openai_key_here
OPENAI_BASE_URL=https://api.openai.com/v1
# Anthropic配置
ANTHROPIC_API_KEY=your_anthropic_key_here
ANTHROPIC_BASE_URL=https://api.anthropic.com
# 默认AI提供商:openai, gemini, anthropic
DEFAULT_AI_PROVIDER=openai
DEFAULT_MODEL=gpt-4.1
DEFAULT_TEMPERATURE=0.8
DEFAULT_MAX_TOKENS=32000
# 应用配置
APP_NAME=MuMuAINovel
APP_VERSION=1.0.0
APP_HOST=0.0.0.0
APP_PORT=8000
DEBUG=true
# LinuxDO OAuth2 配置(可选)
# 注意:Docker部署时,LINUXDO_REDIRECT_URI 应该使用实际的域名或服务器IP
# 本地开发: http://localhost:8000/api/auth/callback
# 生产环境: https://your-domain.com/api/auth/callback 或 http://your-server-ip:8000/api/auth/callback
LINUXDO_CLIENT_ID=your_client_id_here
LINUXDO_CLIENT_SECRET=your_client_secret_here
LINUXDO_REDIRECT_URI=http://localhost:8000/api/auth/callback
# 前端URL配置(用于OAuth回调后重定向到前端)
# 本地开发: http://localhost:8000
# 生产环境: https://your-domain.com 或 http://your-server-ip:8000
FRONTEND_URL=http://localhost:8000
# 本地账户登录配置
# 启用本地账户登录(true/false)
LOCAL_AUTH_ENABLED=true
# 本地登录用户名
LOCAL_AUTH_USERNAME=admin
# 本地登录密码
LOCAL_AUTH_PASSWORD=your_secure_password_here
# 本地用户显示名称
LOCAL_AUTH_DISPLAY_NAME=管理员
# CORS配置(生产环境)
# 允许的跨域来源,多个用逗号分隔
# CORS_ORIGINS=https://your-domain.com,https://www.your-domain.com
+2
View File
@@ -0,0 +1,2 @@
"""AI Story Creator - 后端应用包"""
__version__ = "1.0.0"
+1
View File
@@ -0,0 +1 @@
"""API路由模块"""
+230
View File
@@ -0,0 +1,230 @@
"""
认证 API - LinuxDO OAuth2 登录 + 本地账户登录
"""
from fastapi import APIRouter, HTTPException, Response, Request
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
from typing import Optional
import hashlib
from app.services.oauth_service import LinuxDOOAuthService
from app.user_manager import user_manager
from app.database import init_db
from app.logger import get_logger
from app.config import settings
logger = get_logger(__name__)
router = APIRouter(prefix="/auth", tags=["认证"])
# OAuth2 服务实例
oauth_service = LinuxDOOAuthService()
# State 临时存储(生产环境应使用 Redis)
_state_storage = {}
class AuthUrlResponse(BaseModel):
auth_url: str
state: str
class LocalLoginRequest(BaseModel):
"""本地登录请求"""
username: str
password: str
class LocalLoginResponse(BaseModel):
"""本地登录响应"""
success: bool
message: str
user: Optional[dict] = None
@router.get("/config")
async def get_auth_config():
"""获取认证配置信息"""
return {
"local_auth_enabled": settings.LOCAL_AUTH_ENABLED,
"linuxdo_enabled": bool(settings.LINUXDO_CLIENT_ID and settings.LINUXDO_CLIENT_SECRET)
}
@router.post("/local/login", response_model=LocalLoginResponse)
async def local_login(request: LocalLoginRequest, response: Response):
"""本地账户登录"""
# 检查是否启用本地登录
if not settings.LOCAL_AUTH_ENABLED:
raise HTTPException(status_code=403, detail="本地账户登录未启用")
# 检查是否配置了本地账户
if not settings.LOCAL_AUTH_USERNAME or not settings.LOCAL_AUTH_PASSWORD:
raise HTTPException(status_code=500, detail="本地账户未配置")
# 验证用户名和密码
if request.username != settings.LOCAL_AUTH_USERNAME or request.password != settings.LOCAL_AUTH_PASSWORD:
raise HTTPException(status_code=401, detail="用户名或密码错误")
# 生成本地用户ID(使用用户名的hash)
user_id = f"local_{hashlib.md5(request.username.encode()).hexdigest()[:16]}"
# 创建或更新本地用户
user = await user_manager.create_or_update_from_linuxdo(
linuxdo_id=user_id,
username=request.username,
display_name=settings.LOCAL_AUTH_DISPLAY_NAME,
avatar_url=None,
trust_level=9 # 本地用户给予高信任级别
)
# 初始化用户数据库
try:
await init_db(user.user_id)
logger.info(f"本地用户 {user.user_id} 数据库初始化成功")
except Exception as e:
logger.error(f"本地用户 {user.user_id} 数据库初始化失败: {e}")
# 设置 Cookie7天有效)
response.set_cookie(
key="user_id",
value=user.user_id,
max_age=7 * 24 * 60 * 60, # 7天
httponly=True,
samesite="lax"
)
return LocalLoginResponse(
success=True,
message="登录成功",
user=user.dict()
)
@router.get("/linuxdo/url", response_model=AuthUrlResponse)
async def get_linuxdo_auth_url():
"""获取 LinuxDO 授权 URL"""
state = oauth_service.generate_state()
auth_url = oauth_service.get_authorization_url(state)
# 临时存储 state5分钟有效)
_state_storage[state] = True
return AuthUrlResponse(auth_url=auth_url, state=state)
async def _handle_callback(
code: Optional[str] = None,
state: Optional[str] = None,
error: Optional[str] = None,
response: Response = None
):
"""
LinuxDO OAuth2 回调处理
成功后重定向到前端首页,并设置 user_id Cookie
"""
# 检查是否有错误
if error:
raise HTTPException(status_code=400, detail=f"授权失败: {error}")
# 检查必需参数
if not code or not state:
raise HTTPException(status_code=400, detail="缺少 code 或 state 参数")
# 验证 state(防止 CSRF
if state not in _state_storage:
raise HTTPException(status_code=400, detail="无效的 state 参数")
# 删除已使用的 state
del _state_storage[state]
# 1. 使用 code 获取 access_token
token_data = await oauth_service.get_access_token(code)
if not token_data or "access_token" not in token_data:
raise HTTPException(status_code=400, detail="获取访问令牌失败")
access_token = token_data["access_token"]
# 2. 使用 access_token 获取用户信息
user_info = await oauth_service.get_user_info(access_token)
if not user_info:
raise HTTPException(status_code=400, detail="获取用户信息失败")
# 3. 创建或更新用户
linuxdo_id = str(user_info.get("id"))
username = user_info.get("username", "")
display_name = user_info.get("name", username)
avatar_url = user_info.get("avatar_url")
trust_level = user_info.get("trust_level", 0)
user = await user_manager.create_or_update_from_linuxdo(
linuxdo_id=linuxdo_id,
username=username,
display_name=display_name,
avatar_url=avatar_url,
trust_level=trust_level
)
# 3.5. 初始化用户数据库(如果是新用户)
try:
await init_db(user.user_id)
logger.info(f"用户 {user.user_id} 数据库初始化成功")
except Exception as e:
logger.error(f"用户 {user.user_id} 数据库初始化失败: {e}")
# 继续执行,不影响登录流程(可能是已存在的用户)
# 4. 设置 Cookie 并重定向到前端回调页面
# 使用配置的前端URL,支持不同的部署环境
frontend_url = settings.FRONTEND_URL.rstrip('/')
redirect_url = f"{frontend_url}/auth/callback"
logger.info(f"OAuth回调成功,重定向到前端: {redirect_url}")
redirect_response = RedirectResponse(url=redirect_url)
# 设置 httponly Cookie7天有效)
redirect_response.set_cookie(
key="user_id",
value=user.user_id,
max_age=7 * 24 * 60 * 60, # 7天
httponly=True,
samesite="lax"
)
return redirect_response
@router.get("/linuxdo/callback")
async def linuxdo_callback(
code: Optional[str] = None,
state: Optional[str] = None,
error: Optional[str] = None,
response: Response = None
):
"""LinuxDO OAuth2 回调处理(标准路径)"""
return await _handle_callback(code, state, error, response)
@router.get("/callback")
async def callback_alias(
code: Optional[str] = None,
state: Optional[str] = None,
error: Optional[str] = None,
response: Response = None
):
"""LinuxDO OAuth2 回调处理(兼容路径)"""
return await _handle_callback(code, state, error, response)
@router.post("/logout")
async def logout(response: Response):
"""退出登录"""
response.delete_cookie("user_id")
return {"message": "退出登录成功"}
@router.get("/user")
async def get_current_user(request: Request):
"""获取当前登录用户信息"""
if not hasattr(request.state, "user") or not request.state.user:
raise HTTPException(status_code=401, detail="未登录")
return request.state.user.dict()
+655
View File
@@ -0,0 +1,655 @@
"""章节管理API"""
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
import json
import asyncio
from app.database import get_db
from app.models.chapter import Chapter
from app.models.project import Project
from app.models.outline import Outline
from app.models.character import Character
from app.models.generation_history import GenerationHistory
from app.schemas.chapter import (
ChapterCreate,
ChapterUpdate,
ChapterResponse,
ChapterListResponse
)
from app.services.ai_service import ai_service
from app.services.prompt_service import prompt_service
from app.logger import get_logger
router = APIRouter(prefix="/chapters", tags=["章节管理"])
logger = get_logger(__name__)
@router.post("", response_model=ChapterResponse, summary="创建章节")
async def create_chapter(
chapter: ChapterCreate,
db: AsyncSession = Depends(get_db)
):
"""创建新的章节"""
# 验证项目是否存在
result = await db.execute(
select(Project).where(Project.id == chapter.project_id)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 计算字数
word_count = len(chapter.content)
db_chapter = Chapter(
**chapter.model_dump(),
word_count=word_count
)
db.add(db_chapter)
# 更新项目的当前字数
project.current_words = project.current_words + word_count
await db.commit()
await db.refresh(db_chapter)
return db_chapter
@router.get("/project/{project_id}", response_model=ChapterListResponse, summary="获取项目的所有章节")
async def get_project_chapters(
project_id: str,
db: AsyncSession = Depends(get_db)
):
"""获取指定项目的所有章节(路径参数版本)"""
# 获取总数
count_result = await db.execute(
select(func.count(Chapter.id)).where(Chapter.project_id == project_id)
)
total = count_result.scalar_one()
# 获取章节列表
result = await db.execute(
select(Chapter)
.where(Chapter.project_id == project_id)
.order_by(Chapter.chapter_number)
)
chapters = result.scalars().all()
return ChapterListResponse(total=total, items=chapters)
@router.get("/{chapter_id}", response_model=ChapterResponse, summary="获取章节详情")
async def get_chapter(
chapter_id: str,
db: AsyncSession = Depends(get_db)
):
"""根据ID获取章节详情"""
result = await db.execute(
select(Chapter).where(Chapter.id == chapter_id)
)
chapter = result.scalar_one_or_none()
if not chapter:
raise HTTPException(status_code=404, detail="章节不存在")
return chapter
@router.put("/{chapter_id}", response_model=ChapterResponse, summary="更新章节")
async def update_chapter(
chapter_id: str,
chapter_update: ChapterUpdate,
db: AsyncSession = Depends(get_db)
):
"""更新章节信息"""
result = await db.execute(
select(Chapter).where(Chapter.id == chapter_id)
)
chapter = result.scalar_one_or_none()
if not chapter:
raise HTTPException(status_code=404, detail="章节不存在")
# 记录旧字数
old_word_count = chapter.word_count or 0
# 更新字段
update_data = chapter_update.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(chapter, field, value)
# 如果内容更新了,重新计算字数
if "content" in update_data and chapter.content:
new_word_count = len(chapter.content)
chapter.word_count = new_word_count
# 更新项目字数
result = await db.execute(
select(Project).where(Project.id == chapter.project_id)
)
project = result.scalar_one_or_none()
if project:
project.current_words = project.current_words - old_word_count + new_word_count
await db.commit()
await db.refresh(chapter)
return chapter
@router.delete("/{chapter_id}", summary="删除章节")
async def delete_chapter(
chapter_id: str,
db: AsyncSession = Depends(get_db)
):
"""删除章节"""
result = await db.execute(
select(Chapter).where(Chapter.id == chapter_id)
)
chapter = result.scalar_one_or_none()
if not chapter:
raise HTTPException(status_code=404, detail="章节不存在")
# 更新项目字数
result = await db.execute(
select(Project).where(Project.id == chapter.project_id)
)
project = result.scalar_one_or_none()
if project:
project.current_words = max(0, project.current_words - chapter.word_count)
await db.delete(chapter)
await db.commit()
return {"message": "章节删除成功"}
async def check_prerequisites(db: AsyncSession, chapter: Chapter) -> tuple[bool, str, list[Chapter]]:
"""
检查章节前置条件
Args:
db: 数据库会话
chapter: 当前章节
Returns:
(可否生成, 错误信息, 前置章节列表)
"""
# 如果是第一章,无需检查前置
if chapter.chapter_number == 1:
return True, "", []
# 查询所有前置章节(序号小于当前章节的)
result = await db.execute(
select(Chapter)
.where(Chapter.project_id == chapter.project_id)
.where(Chapter.chapter_number < chapter.chapter_number)
.order_by(Chapter.chapter_number)
)
previous_chapters = result.scalars().all()
# 检查是否所有前置章节都有内容
incomplete_chapters = [
ch for ch in previous_chapters
if not ch.content or ch.content.strip() == ""
]
if incomplete_chapters:
missing_numbers = [str(ch.chapter_number) for ch in incomplete_chapters]
error_msg = f"需要先完成前置章节:第 {', '.join(missing_numbers)}"
return False, error_msg, previous_chapters
return True, "", previous_chapters
@router.get("/{chapter_id}/can-generate", summary="检查章节是否可以生成")
async def check_can_generate(
chapter_id: str,
db: AsyncSession = Depends(get_db)
):
"""
检查章节是否满足生成条件
返回可生成状态和前置章节信息
"""
# 获取章节
result = await db.execute(
select(Chapter).where(Chapter.id == chapter_id)
)
chapter = result.scalar_one_or_none()
if not chapter:
raise HTTPException(status_code=404, detail="章节不存在")
# 检查前置条件
can_generate, error_msg, previous_chapters = await check_prerequisites(db, chapter)
# 构建前置章节信息
previous_info = [
{
"id": ch.id,
"chapter_number": ch.chapter_number,
"title": ch.title,
"has_content": bool(ch.content and ch.content.strip()),
"word_count": ch.word_count or 0
}
for ch in previous_chapters
]
return {
"can_generate": can_generate,
"reason": error_msg if not can_generate else "",
"previous_chapters": previous_info,
"chapter_number": chapter.chapter_number
}
@router.post("/{chapter_id}/generate", summary="AI创作章节内容")
async def generate_chapter_content(
chapter_id: str,
db: AsyncSession = Depends(get_db)
):
"""
根据大纲、前置章节内容和项目信息AI创作章节完整内容
要求:必须按顺序生成,确保前置章节都已完成
"""
# 获取章节
result = await db.execute(
select(Chapter).where(Chapter.id == chapter_id)
)
chapter = result.scalar_one_or_none()
if not chapter:
raise HTTPException(status_code=404, detail="章节不存在")
# 检查前置条件
can_generate, error_msg, previous_chapters = await check_prerequisites(db, chapter)
if not can_generate:
raise HTTPException(status_code=400, detail=error_msg)
try:
# 获取项目信息
project_result = await db.execute(
select(Project).where(Project.id == chapter.project_id)
)
project = project_result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 获取对应的大纲(使用新的查询确保获取最新数据)
outline_result = await db.execute(
select(Outline)
.where(Outline.project_id == chapter.project_id)
.where(Outline.order_index == chapter.chapter_number)
.execution_options(populate_existing=True)
)
outline = outline_result.scalar_one_or_none()
# 获取所有大纲用于上下文(使用新的查询确保获取最新数据)
all_outlines_result = await db.execute(
select(Outline)
.where(Outline.project_id == chapter.project_id)
.order_by(Outline.order_index)
.execution_options(populate_existing=True)
)
all_outlines = all_outlines_result.scalars().all()
outlines_context = "\n".join([
f"{o.order_index}{o.title}: {o.content[:100]}..."
for o in all_outlines
])
# 获取角色信息
characters_result = await db.execute(
select(Character).where(Character.project_id == chapter.project_id)
)
characters = characters_result.scalars().all()
characters_info = "\n".join([
f"- {c.name}({'组织' if c.is_organization else '角色'}, {c.role_type}): {c.personality[:100] if c.personality else ''}"
for c in characters
])
# 构建前置章节内容上下文(如果有前置章节)
previous_content = ""
if previous_chapters:
# Token控制:保留最近3章的完整内容,早期章节使用摘要
recent_chapters = previous_chapters[-3:] if len(previous_chapters) > 3 else previous_chapters
early_chapters = previous_chapters[:-3] if len(previous_chapters) > 3 else []
# 早期章节摘要
if early_chapters:
early_summary = "【前期剧情概要】\n" + "\n".join([
f"{ch.chapter_number}章《{ch.title}》:{ch.content[:200] if ch.content else ''}..."
for ch in early_chapters
])
previous_content += early_summary + "\n\n"
# 最近章节完整内容
if recent_chapters:
recent_content = "【最近章节完整内容】\n" + "\n\n".join([
f"=== 第{ch.chapter_number}章:{ch.title} ===\n{ch.content}"
for ch in recent_chapters
])
previous_content += recent_content
logger.info(f"构建前置上下文:{len(early_chapters)}章摘要 + {len(recent_chapters)}章完整内容")
# 根据是否有前置内容选择不同的提示词
if previous_content:
# 使用带上下文的提示词
prompt = prompt_service.get_chapter_generation_with_context_prompt(
title=project.title,
theme=project.theme or '',
genre=project.genre or '',
narrative_perspective=project.narrative_perspective or '第三人称',
time_period=project.world_time_period or '未设定',
location=project.world_location or '未设定',
atmosphere=project.world_atmosphere or '未设定',
rules=project.world_rules or '未设定',
characters_info=characters_info or '暂无角色信息',
outlines_context=outlines_context,
previous_content=previous_content,
chapter_number=chapter.chapter_number,
chapter_title=chapter.title,
chapter_outline=outline.content if outline else chapter.summary or '暂无大纲'
)
else:
# 第一章,使用原有提示词
prompt = prompt_service.get_chapter_generation_prompt(
title=project.title,
theme=project.theme or '',
genre=project.genre or '',
narrative_perspective=project.narrative_perspective or '第三人称',
time_period=project.world_time_period or '未设定',
location=project.world_location or '未设定',
atmosphere=project.world_atmosphere or '未设定',
rules=project.world_rules or '未设定',
characters_info=characters_info or '暂无角色信息',
outlines_context=outlines_context,
chapter_number=chapter.chapter_number,
chapter_title=chapter.title,
chapter_outline=outline.content if outline else chapter.summary or '暂无大纲'
)
logger.info(f"开始AI创作章节 {chapter_id}")
# 调用AI生成
ai_content = await ai_service.generate_text(
prompt=prompt
)
# 更新章节内容
old_word_count = chapter.word_count or 0
chapter.content = ai_content
new_word_count = len(ai_content)
chapter.word_count = new_word_count
chapter.status = "completed"
# 更新项目字数
project.current_words = project.current_words - old_word_count + new_word_count
# 记录生成历史
history = GenerationHistory(
project_id=chapter.project_id,
chapter_id=chapter.id,
prompt=f"创作章节: 第{chapter.chapter_number}{chapter.title}",
generated_content=ai_content[:500] if len(ai_content) > 500 else ai_content,
model="default"
)
db.add(history)
await db.commit()
await db.refresh(chapter)
logger.info(f"成功创作章节 {chapter_id},共 {new_word_count}")
return {"content": ai_content}
except Exception as e:
logger.error(f"创作章节失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"创作章节失败: {str(e)}")
@router.post("/{chapter_id}/generate-stream", summary="AI创作章节内容(流式)")
async def generate_chapter_content_stream(
chapter_id: str,
request: Request
):
"""
根据大纲、前置章节内容和项目信息AI创作章节完整内容(流式返回)
要求:必须按顺序生成,确保前置章节都已完成
注意:此函数不使用依赖注入的db,而是在生成器内部创建独立的数据库会话
以避免流式响应期间的连接泄漏问题
"""
# 预先验证章节存在性(使用临时会话)
async for temp_db in get_db(request):
try:
result = await temp_db.execute(
select(Chapter).where(Chapter.id == chapter_id)
)
chapter = result.scalar_one_or_none()
if not chapter:
raise HTTPException(status_code=404, detail="章节不存在")
# 检查前置条件
can_generate, error_msg, previous_chapters = await check_prerequisites(temp_db, chapter)
if not can_generate:
raise HTTPException(status_code=400, detail=error_msg)
# 保存前置章节数据供生成器使用
previous_chapters_data = [
{
'id': ch.id,
'chapter_number': ch.chapter_number,
'title': ch.title,
'content': ch.content
}
for ch in previous_chapters
]
finally:
await temp_db.close()
break
async def event_generator():
# 在生成器内部创建独立的数据库会话
db_session = None
db_committed = False
try:
# 创建新的数据库会话
async for db_session in get_db(request):
# 重新获取章节信息
chapter_result = await db_session.execute(
select(Chapter).where(Chapter.id == chapter_id)
)
current_chapter = chapter_result.scalar_one_or_none()
if not current_chapter:
yield f"data: {json.dumps({'type': 'error', 'error': '章节不存在'}, ensure_ascii=False)}\n\n"
return
# 获取项目信息
project_result = await db_session.execute(
select(Project).where(Project.id == current_chapter.project_id)
)
project = project_result.scalar_one_or_none()
if not project:
yield f"data: {json.dumps({'type': 'error', 'error': '项目不存在'}, ensure_ascii=False)}\n\n"
return
# 获取对应的大纲
outline_result = await db_session.execute(
select(Outline)
.where(Outline.project_id == current_chapter.project_id)
.where(Outline.order_index == current_chapter.chapter_number)
.execution_options(populate_existing=True)
)
outline = outline_result.scalar_one_or_none()
# 获取所有大纲用于上下文
all_outlines_result = await db_session.execute(
select(Outline)
.where(Outline.project_id == current_chapter.project_id)
.order_by(Outline.order_index)
.execution_options(populate_existing=True)
)
all_outlines = all_outlines_result.scalars().all()
outlines_context = "\n".join([
f"{o.order_index}{o.title}: {o.content[:100]}..."
for o in all_outlines
])
# 获取角色信息
characters_result = await db_session.execute(
select(Character).where(Character.project_id == current_chapter.project_id)
)
characters = characters_result.scalars().all()
characters_info = "\n".join([
f"- {c.name}({'组织' if c.is_organization else '角色'}, {c.role_type}): {c.personality[:100] if c.personality else ''}"
for c in characters
])
# 构建前置章节内容上下文(使用之前保存的数据)
previous_content = ""
if previous_chapters_data:
recent_chapters = previous_chapters_data[-3:] if len(previous_chapters_data) > 3 else previous_chapters_data
early_chapters = previous_chapters_data[:-3] if len(previous_chapters_data) > 3 else []
if early_chapters:
early_summary = "【前期剧情概要】\n" + "\n".join([
f"{ch['chapter_number']}章《{ch['title']}》:{ch['content'][:200] if ch['content'] else ''}..."
for ch in early_chapters
])
previous_content += early_summary + "\n\n"
if recent_chapters:
recent_content = "【最近章节完整内容】\n" + "\n\n".join([
f"=== 第{ch['chapter_number']}章:{ch['title']} ===\n{ch['content']}"
for ch in recent_chapters
])
previous_content += recent_content
logger.info(f"构建前置上下文:{len(early_chapters)}章摘要 + {len(recent_chapters)}章完整内容")
# 发送开始事件
yield f"data: {json.dumps({'type': 'start', 'message': '开始AI创作...'}, ensure_ascii=False)}\n\n"
# 根据是否有前置内容选择不同的提示词
if previous_content:
prompt = prompt_service.get_chapter_generation_with_context_prompt(
title=project.title,
theme=project.theme or '',
genre=project.genre or '',
narrative_perspective=project.narrative_perspective or '第三人称',
time_period=project.world_time_period or '未设定',
location=project.world_location or '未设定',
atmosphere=project.world_atmosphere or '未设定',
rules=project.world_rules or '未设定',
characters_info=characters_info or '暂无角色信息',
outlines_context=outlines_context,
previous_content=previous_content,
chapter_number=current_chapter.chapter_number,
chapter_title=current_chapter.title,
chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲'
)
else:
prompt = prompt_service.get_chapter_generation_prompt(
title=project.title,
theme=project.theme or '',
genre=project.genre or '',
narrative_perspective=project.narrative_perspective or '第三人称',
time_period=project.world_time_period or '未设定',
location=project.world_location or '未设定',
atmosphere=project.world_atmosphere or '未设定',
rules=project.world_rules or '未设定',
characters_info=characters_info or '暂无角色信息',
outlines_context=outlines_context,
chapter_number=current_chapter.chapter_number,
chapter_title=current_chapter.title,
chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲'
)
logger.info(f"开始AI流式创作章节 {chapter_id}")
# 流式生成内容
full_content = ""
async for chunk in ai_service.generate_text_stream(prompt=prompt):
full_content += chunk
yield f"data: {json.dumps({'type': 'content', 'content': chunk}, ensure_ascii=False)}\n\n"
await asyncio.sleep(0) # 让出控制权
# 更新章节内容到数据库
old_word_count = current_chapter.word_count or 0
current_chapter.content = full_content
new_word_count = len(full_content)
current_chapter.word_count = new_word_count
current_chapter.status = "completed"
# 更新项目字数
project.current_words = project.current_words - old_word_count + new_word_count
# 记录生成历史
history = GenerationHistory(
project_id=current_chapter.project_id,
chapter_id=current_chapter.id,
prompt=f"创作章节: 第{current_chapter.chapter_number}{current_chapter.title}",
generated_content=full_content[:500] if len(full_content) > 500 else full_content,
model="default"
)
db_session.add(history)
await db_session.commit()
db_committed = True
await db_session.refresh(current_chapter)
logger.info(f"成功创作章节 {chapter_id},共 {new_word_count}")
# 发送完成事件
yield f"data: {json.dumps({'type': 'done', 'message': '创作完成', 'word_count': new_word_count}, ensure_ascii=False)}\n\n"
break # 退出async for db_session循环
except GeneratorExit:
# SSE连接断开
logger.warning("章节生成器被提前关闭(SSE断开)")
if db_session and not db_committed:
try:
if db_session.in_transaction():
await db_session.rollback()
logger.info("章节生成事务已回滚(GeneratorExit")
except Exception as e:
logger.error(f"GeneratorExit回滚失败: {str(e)}")
except Exception as e:
logger.error(f"流式创作章节失败: {str(e)}")
if db_session and not db_committed:
try:
if db_session.in_transaction():
await db_session.rollback()
logger.info("章节生成事务已回滚(异常)")
except Exception as rollback_error:
logger.error(f"回滚失败: {str(rollback_error)}")
yield f"data: {json.dumps({'type': 'error', 'error': str(e)}, ensure_ascii=False)}\n\n"
finally:
# 确保数据库会话被正确关闭
if db_session:
try:
# 最后检查:确保没有未提交的事务
if not db_committed and db_session.in_transaction():
await db_session.rollback()
logger.warning("在finally中发现未提交事务,已回滚")
await db_session.close()
logger.info("数据库会话已关闭")
except Exception as close_error:
logger.error(f"关闭数据库会话失败: {str(close_error)}")
# 强制关闭
try:
await db_session.close()
except:
pass
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
+491
View File
@@ -0,0 +1,491 @@
"""角色管理API"""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
import json
from app.database import get_db
from app.models.character import Character
from app.models.project import Project
from app.models.generation_history import GenerationHistory
from app.models.relationship import CharacterRelationship, Organization, OrganizationMember, RelationshipType
from app.schemas.character import (
CharacterUpdate,
CharacterResponse,
CharacterListResponse,
CharacterGenerateRequest
)
from app.services.ai_service import ai_service
from app.services.prompt_service import prompt_service
from app.logger import get_logger
router = APIRouter(prefix="/characters", tags=["角色管理"])
logger = get_logger(__name__)
@router.get("", response_model=CharacterListResponse, summary="获取角色列表")
async def get_characters(
project_id: str,
db: AsyncSession = Depends(get_db)
):
"""获取指定项目的所有角色(query参数版本)"""
# 获取总数
count_result = await db.execute(
select(func.count(Character.id)).where(Character.project_id == project_id)
)
total = count_result.scalar_one()
# 获取角色列表
result = await db.execute(
select(Character)
.where(Character.project_id == project_id)
.order_by(Character.created_at.desc())
)
characters = result.scalars().all()
return CharacterListResponse(total=total, items=characters)
@router.get("/project/{project_id}", response_model=CharacterListResponse, summary="获取项目的所有角色")
async def get_project_characters(
project_id: str,
db: AsyncSession = Depends(get_db)
):
"""获取指定项目的所有角色(路径参数版本)"""
# 获取总数
count_result = await db.execute(
select(func.count(Character.id)).where(Character.project_id == project_id)
)
total = count_result.scalar_one()
# 获取角色列表
result = await db.execute(
select(Character)
.where(Character.project_id == project_id)
.order_by(Character.created_at.desc())
)
characters = result.scalars().all()
return CharacterListResponse(total=total, items=characters)
@router.get("/{character_id}", response_model=CharacterResponse, summary="获取角色详情")
async def get_character(
character_id: str,
db: AsyncSession = Depends(get_db)
):
"""根据ID获取角色详情"""
result = await db.execute(
select(Character).where(Character.id == character_id)
)
character = result.scalar_one_or_none()
if not character:
raise HTTPException(status_code=404, detail="角色不存在")
return character
@router.put("/{character_id}", response_model=CharacterResponse, summary="更新角色")
async def update_character(
character_id: str,
character_update: CharacterUpdate,
db: AsyncSession = Depends(get_db)
):
"""更新角色信息"""
result = await db.execute(
select(Character).where(Character.id == character_id)
)
character = result.scalar_one_or_none()
if not character:
raise HTTPException(status_code=404, detail="角色不存在")
# 更新字段
update_data = character_update.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(character, field, value)
await db.commit()
await db.refresh(character)
return character
@router.delete("/{character_id}", summary="删除角色")
async def delete_character(
character_id: str,
db: AsyncSession = Depends(get_db)
):
"""删除角色"""
result = await db.execute(
select(Character).where(Character.id == character_id)
)
character = result.scalar_one_or_none()
if not character:
raise HTTPException(status_code=404, detail="角色不存在")
await db.delete(character)
await db.commit()
return {"message": "角色删除成功"}
@router.post("/generate", response_model=CharacterResponse, summary="AI生成角色")
async def generate_character(
request: CharacterGenerateRequest,
db: AsyncSession = Depends(get_db)
):
"""
使用AI生成角色卡
根据用户输入的信息,结合项目的世界观、主题等背景,
AI会生成一个完整、详细的角色设定卡片。
生成内容包括:姓名、年龄、性别、性格、外貌、背景故事、人际关系等
"""
# 验证项目是否存在并获取项目信息
result = await db.execute(
select(Project).where(Project.id == request.project_id)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
try:
# 获取已存在的角色列表,用于关系网络
existing_chars_result = await db.execute(
select(Character)
.where(Character.project_id == request.project_id)
.order_by(Character.created_at.desc())
)
existing_characters = existing_chars_result.scalars().all()
# 构建现有角色信息摘要(包含组织)
existing_chars_info = ""
character_list = []
organization_list = []
if existing_characters:
for c in existing_characters[:10]: # 最多显示10个
if c.is_organization:
organization_list.append(f"- {c.name} [{c.organization_type or '组织'}]")
else:
character_list.append(f"- {c.name}{c.role_type or '未知'}")
if character_list:
existing_chars_info += "\n已有角色:\n" + "\n".join(character_list)
if organization_list:
existing_chars_info += "\n\n已有组织:\n" + "\n".join(organization_list)
# 构建项目上下文信息
project_context = f"""
项目信息:
- 书名:{project.title}
- 主题:{project.theme or '未设定'}
- 类型:{project.genre or '未设定'}
- 时间背景:{project.world_time_period or '未设定'}
- 地理位置:{project.world_location or '未设定'}
- 氛围基调:{project.world_atmosphere or '未设定'}
- 世界规则:{project.world_rules or '未设定'}
{existing_chars_info}
"""
# 构建用户输入信息
user_input = f"""
用户要求:
- 角色名称:{request.name or '请AI生成'}
- 角色定位:{request.role_type or 'supporting'}protagonist=主角, supporting=配角, antagonist=反派)
- 背景设定:{request.background or '无特殊要求'}
- 其他要求:{request.requirements or ''}
"""
# 使用统一的提示词服务
prompt = prompt_service.get_single_character_prompt(
project_context=project_context,
user_input=user_input
)
# 调用AI生成角色
logger.info(f"🎯 开始为项目 {request.project_id} 生成角色")
logger.info(f" - 角色名:{request.name or 'AI生成'}")
logger.info(f" - 角色定位:{request.role_type}")
logger.info(f" - 背景设定:{request.background or ''}")
logger.info(f" - AI提供商:{request.provider or 'default'}")
logger.info(f" - AI模型:{request.model or 'default'}")
logger.info(f" - Prompt长度:{len(prompt)} 字符")
try:
ai_response = await ai_service.generate_text(
prompt=prompt,
provider=request.provider,
model=request.model
)
logger.info(f"✅ AI响应接收完成,长度:{len(ai_response) if ai_response else 0} 字符")
except Exception as ai_error:
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
raise HTTPException(
status_code=500,
detail=f"AI服务调用失败:{str(ai_error)}"
)
# 检查AI响应
if not ai_response or not ai_response.strip():
logger.error("❌ AI返回了空响应")
raise HTTPException(
status_code=500,
detail="AI服务返回空响应。可能原因:1) API配置错误 2) 模型不支持 3) 网络问题。请检查后端日志。"
)
logger.info(f"📝 开始清理AI响应")
# 清理AI响应,移除可能的markdown标记
cleaned_response = ai_response.strip()
original_length = len(cleaned_response)
if cleaned_response.startswith("```json"):
cleaned_response = cleaned_response[7:]
logger.info(" - 移除了 ```json 标记")
if cleaned_response.startswith("```"):
cleaned_response = cleaned_response[3:]
logger.info(" - 移除了 ``` 标记")
if cleaned_response.endswith("```"):
cleaned_response = cleaned_response[:-3]
logger.info(" - 移除了末尾 ``` 标记")
cleaned_response = cleaned_response.strip()
logger.info(f" - 清理前长度:{original_length},清理后长度:{len(cleaned_response)}")
logger.info(f" - 清理后内容预览(前300字符):{cleaned_response[:300]}")
# 解析AI响应
logger.info(f"🔍 开始解析JSON")
try:
character_data = json.loads(cleaned_response)
logger.info(f"✅ JSON解析成功")
logger.info(f" - 解析后的字段:{list(character_data.keys())}")
except json.JSONDecodeError as e:
logger.error(f"❌ JSON解析失败")
logger.error(f" - 错误位置:line {e.lineno}, column {e.colno}")
logger.error(f" - 错误信息:{str(e)}")
logger.error(f" - 完整响应内容(前1000字符):{cleaned_response[:1000]}")
raise HTTPException(
status_code=500,
detail=f"AI返回的内容无法解析为JSON。错误:{str(e)}。响应内容已记录到日志,请查看后端日志排查。"
)
# 转换traits为JSON字符串
traits_json = json.dumps(character_data.get("traits", []), ensure_ascii=False) if character_data.get("traits") else None
# 判断是否为组织
is_organization = character_data.get("is_organization", False)
# 创建角色
character = Character(
project_id=request.project_id,
name=character_data.get("name", request.name or "未命名角色"),
age=str(character_data.get("age", "")),
gender=character_data.get("gender"),
is_organization=is_organization,
role_type=request.role_type or "supporting",
personality=character_data.get("personality", ""),
background=character_data.get("background", ""),
appearance=character_data.get("appearance", ""),
relationships=character_data.get("relationships_text", character_data.get("relationships", "")), # 优先使用文本描述
organization_type=character_data.get("organization_type") if is_organization else None,
organization_purpose=character_data.get("organization_purpose") if is_organization else None,
organization_members=json.dumps(character_data.get("organization_members", []), ensure_ascii=False) if is_organization else None,
traits=traits_json
)
db.add(character)
await db.flush() # 获取character.id
logger.info(f"✅ 角色创建成功:{character.name} (ID: {character.id}, 是否组织: {is_organization})")
# 如果是组织,自动创建Organization详情记录
if is_organization:
org_check = await db.execute(
select(Organization).where(Organization.character_id == character.id)
)
existing_org = org_check.scalar_one_or_none()
if not existing_org:
organization = Organization(
character_id=character.id,
project_id=request.project_id,
member_count=0,
power_level=character_data.get("power_level", 50),
location=character_data.get("location"),
motto=character_data.get("motto")
)
db.add(organization)
await db.flush()
logger.info(f"✅ 自动创建组织详情:{character.name} (Org ID: {organization.id})")
else:
logger.info(f"️ 组织详情已存在:{character.name}")
# 处理结构化关系数据(仅针对非组织角色)
if not is_organization:
relationships_data = character_data.get("relationships", [])
if relationships_data and isinstance(relationships_data, list):
logger.info(f"📊 开始处理 {len(relationships_data)} 条关系数据")
created_rels = 0
for rel in relationships_data:
try:
target_name = rel.get("target_character_name")
if not target_name:
logger.debug(f" ⚠️ 关系缺少target_character_name,跳过")
continue
target_result = await db.execute(
select(Character).where(
Character.project_id == request.project_id,
Character.name == target_name
)
)
target_char = target_result.scalar_one_or_none()
if target_char:
# 检查是否已存在相同关系
existing_rel = await db.execute(
select(CharacterRelationship).where(
CharacterRelationship.project_id == request.project_id,
CharacterRelationship.character_from_id == character.id,
CharacterRelationship.character_to_id == target_char.id
)
)
if existing_rel.scalar_one_or_none():
logger.debug(f" ️ 关系已存在:{character.name} -> {target_name}")
continue
relationship = CharacterRelationship(
project_id=request.project_id,
character_from_id=character.id,
character_to_id=target_char.id,
relationship_name=rel.get("relationship_type", "未知关系"),
intimacy_level=rel.get("intimacy_level", 50),
description=rel.get("description", ""),
started_at=rel.get("started_at"),
source="ai"
)
# 匹配预定义关系类型
rel_type_result = await db.execute(
select(RelationshipType).where(
RelationshipType.name == rel.get("relationship_type")
)
)
rel_type = rel_type_result.scalar_one_or_none()
if rel_type:
relationship.relationship_type_id = rel_type.id
db.add(relationship)
created_rels += 1
logger.info(f" ✅ 创建关系:{character.name} -> {target_name} ({rel.get('relationship_type')})")
else:
logger.warning(f" ⚠️ 目标角色不存在:{target_name}")
except Exception as rel_error:
logger.warning(f" ❌ 创建关系失败:{str(rel_error)}")
continue
logger.info(f"✅ 成功创建 {created_rels} 条关系记录")
# 处理组织成员关系(仅针对非组织角色)
if not is_organization:
org_memberships = character_data.get("organization_memberships", [])
if org_memberships and isinstance(org_memberships, list):
logger.info(f"🏢 开始处理 {len(org_memberships)} 条组织成员关系")
created_members = 0
for membership in org_memberships:
try:
org_name = membership.get("organization_name")
if not org_name:
logger.debug(f" ⚠️ 组织成员关系缺少organization_name,跳过")
continue
org_char_result = await db.execute(
select(Character).where(
Character.project_id == request.project_id,
Character.name == org_name,
Character.is_organization == True
)
)
org_char = org_char_result.scalar_one_or_none()
if org_char:
# 获取或创建Organization记录
org_result = await db.execute(
select(Organization).where(Organization.character_id == org_char.id)
)
org = org_result.scalar_one_or_none()
if not org:
# 如果组织Character存在但Organization不存在,自动创建
org = Organization(
character_id=org_char.id,
project_id=request.project_id,
member_count=0
)
db.add(org)
await db.flush()
logger.info(f" ️ 自动创建缺失的组织详情:{org_name}")
# 检查是否已存在成员关系
existing_member = await db.execute(
select(OrganizationMember).where(
OrganizationMember.organization_id == org.id,
OrganizationMember.character_id == character.id
)
)
if existing_member.scalar_one_or_none():
logger.debug(f" ️ 成员关系已存在:{character.name} -> {org_name}")
continue
# 创建成员关系
member = OrganizationMember(
organization_id=org.id,
character_id=character.id,
position=membership.get("position", "成员"),
rank=membership.get("rank", 0),
loyalty=membership.get("loyalty", 50),
joined_at=membership.get("joined_at"),
status=membership.get("status", "active"),
source="ai"
)
db.add(member)
# 更新组织成员计数
org.member_count += 1
created_members += 1
logger.info(f" ✅ 添加成员:{character.name} -> {org_name} ({membership.get('position')})")
else:
logger.warning(f" ⚠️ 组织不存在:{org_name}")
except Exception as org_error:
logger.warning(f" ❌ 添加组织成员失败:{str(org_error)}")
continue
logger.info(f"✅ 成功创建 {created_members} 条组织成员记录")
# 记录生成历史
history = GenerationHistory(
project_id=request.project_id,
prompt=prompt,
generated_content=ai_response,
model=request.model or "default"
)
db.add(history)
await db.commit()
await db.refresh(character)
logger.info(f"🎉 成功为项目 {request.project_id} 生成角色: {character.name}")
return character
except Exception as e:
logger.error(f"生成角色失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"生成角色失败: {str(e)}")
+341
View File
@@ -0,0 +1,341 @@
"""组织管理API"""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_
from typing import List
from app.database import get_db
from app.models.relationship import Organization, OrganizationMember
from app.models.character import Character
from app.schemas.relationship import (
OrganizationCreate,
OrganizationUpdate,
OrganizationResponse,
OrganizationDetailResponse,
OrganizationMemberCreate,
OrganizationMemberUpdate,
OrganizationMemberResponse,
OrganizationMemberDetailResponse
)
from app.logger import get_logger
router = APIRouter(prefix="/organizations", tags=["组织管理"])
logger = get_logger(__name__)
@router.get("/project/{project_id}", response_model=List[OrganizationDetailResponse], summary="获取项目的所有组织")
async def get_project_organizations(
project_id: str,
db: AsyncSession = Depends(get_db)
):
"""
获取项目中的所有组织及其详情
返回组织的基本信息和统计数据
"""
result = await db.execute(
select(Organization).where(Organization.project_id == project_id)
)
organizations = result.scalars().all()
# 获取每个组织的角色信息
org_list = []
for org in organizations:
char_result = await db.execute(
select(Character).where(Character.id == org.character_id)
)
char = char_result.scalar_one_or_none()
if char:
org_list.append(OrganizationDetailResponse(
id=org.id,
character_id=org.character_id,
name=char.name,
type=char.organization_type,
purpose=char.organization_purpose,
member_count=org.member_count,
power_level=org.power_level,
location=org.location,
motto=org.motto,
color=org.color
))
logger.info(f"获取项目 {project_id} 的组织列表,共 {len(org_list)}")
return org_list
@router.get("/{org_id}", response_model=OrganizationResponse, summary="获取组织详情")
async def get_organization(
org_id: str,
db: AsyncSession = Depends(get_db)
):
"""获取组织的详细信息"""
result = await db.execute(
select(Organization).where(Organization.id == org_id)
)
org = result.scalar_one_or_none()
if not org:
raise HTTPException(status_code=404, detail="组织不存在")
return org
@router.post("/", response_model=OrganizationResponse, summary="创建组织")
async def create_organization(
organization: OrganizationCreate,
db: AsyncSession = Depends(get_db)
):
"""
创建新组织
- 需要关联到一个已存在的角色记录(is_organization=True
- 可以设置父组织、势力等级等属性
"""
# 验证角色是否存在且是组织
char_result = await db.execute(
select(Character).where(Character.id == organization.character_id)
)
char = char_result.scalar_one_or_none()
if not char:
raise HTTPException(status_code=404, detail="关联的角色不存在")
if not char.is_organization:
raise HTTPException(status_code=400, detail="关联的角色不是组织类型")
# 检查是否已存在
existing = await db.execute(
select(Organization).where(Organization.character_id == organization.character_id)
)
if existing.scalar_one_or_none():
raise HTTPException(status_code=400, detail="该角色已有组织详情记录")
# 创建组织
db_org = Organization(**organization.model_dump())
db.add(db_org)
await db.commit()
await db.refresh(db_org)
logger.info(f"创建组织成功:{db_org.id} - {char.name}")
return db_org
@router.put("/{org_id}", response_model=OrganizationResponse, summary="更新组织")
async def update_organization(
org_id: str,
organization: OrganizationUpdate,
db: AsyncSession = Depends(get_db)
):
"""更新组织的属性"""
result = await db.execute(
select(Organization).where(Organization.id == org_id)
)
db_org = result.scalar_one_or_none()
if not db_org:
raise HTTPException(status_code=404, detail="组织不存在")
# 更新字段
update_data = organization.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(db_org, field, value)
await db.commit()
await db.refresh(db_org)
logger.info(f"更新组织成功:{org_id}")
return db_org
@router.delete("/{org_id}", summary="删除组织")
async def delete_organization(
org_id: str,
db: AsyncSession = Depends(get_db)
):
"""删除组织(会级联删除所有成员关系)"""
result = await db.execute(
select(Organization).where(Organization.id == org_id)
)
db_org = result.scalar_one_or_none()
if not db_org:
raise HTTPException(status_code=404, detail="组织不存在")
await db.delete(db_org)
await db.commit()
logger.info(f"删除组织成功:{org_id}")
return {"message": "组织删除成功", "id": org_id}
# ============ 组织成员管理 ============
@router.get("/{org_id}/members", response_model=List[OrganizationMemberDetailResponse], summary="获取组织成员")
async def get_organization_members(
org_id: str,
db: AsyncSession = Depends(get_db)
):
"""
获取组织的所有成员
按职位等级(rank)降序排列
"""
# 验证组织存在
org_result = await db.execute(
select(Organization).where(Organization.id == org_id)
)
if not org_result.scalar_one_or_none():
raise HTTPException(status_code=404, detail="组织不存在")
# 获取成员列表
result = await db.execute(
select(OrganizationMember)
.where(OrganizationMember.organization_id == org_id)
.order_by(OrganizationMember.rank.desc(), OrganizationMember.created_at)
)
members = result.scalars().all()
# 获取成员角色信息
member_list = []
for member in members:
char_result = await db.execute(
select(Character).where(Character.id == member.character_id)
)
char = char_result.scalar_one_or_none()
if char:
member_list.append(OrganizationMemberDetailResponse(
id=member.id,
character_id=member.character_id,
character_name=char.name,
position=member.position,
rank=member.rank,
loyalty=member.loyalty,
contribution=member.contribution,
status=member.status,
joined_at=member.joined_at,
left_at=member.left_at,
notes=member.notes
))
logger.info(f"获取组织 {org_id} 的成员列表,共 {len(member_list)}")
return member_list
@router.post("/{org_id}/members", response_model=OrganizationMemberResponse, summary="添加组织成员")
async def add_organization_member(
org_id: str,
member: OrganizationMemberCreate,
db: AsyncSession = Depends(get_db)
):
"""
添加角色到组织
- 一个角色在同一组织中只能有一个职位
- 会自动更新组织的成员计数
"""
# 验证组织存在
org_result = await db.execute(
select(Organization).where(Organization.id == org_id)
)
org = org_result.scalar_one_or_none()
if not org:
raise HTTPException(status_code=404, detail="组织不存在")
# 验证角色存在
char_result = await db.execute(
select(Character).where(Character.id == member.character_id)
)
char = char_result.scalar_one_or_none()
if not char:
raise HTTPException(status_code=404, detail="角色不存在")
if char.is_organization:
raise HTTPException(status_code=400, detail="不能将组织添加为成员")
# 检查是否已存在
existing = await db.execute(
select(OrganizationMember).where(
and_(
OrganizationMember.organization_id == org_id,
OrganizationMember.character_id == member.character_id
)
)
)
if existing.scalar_one_or_none():
raise HTTPException(status_code=400, detail="该角色已在组织中")
# 创建成员关系
db_member = OrganizationMember(
organization_id=org_id,
**member.model_dump(),
source="manual"
)
db.add(db_member)
# 更新组织成员计数
org.member_count += 1
await db.commit()
await db.refresh(db_member)
logger.info(f"添加成员成功:{char.name} 加入组织 {org_id}")
return db_member
@router.put("/members/{member_id}", response_model=OrganizationMemberResponse, summary="更新成员信息")
async def update_organization_member(
member_id: str,
member: OrganizationMemberUpdate,
db: AsyncSession = Depends(get_db)
):
"""更新组织成员的职位、忠诚度等信息"""
result = await db.execute(
select(OrganizationMember).where(OrganizationMember.id == member_id)
)
db_member = result.scalar_one_or_none()
if not db_member:
raise HTTPException(status_code=404, detail="成员记录不存在")
# 更新字段
update_data = member.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(db_member, field, value)
await db.commit()
await db.refresh(db_member)
logger.info(f"更新成员信息成功:{member_id}")
return db_member
@router.delete("/members/{member_id}", summary="移除组织成员")
async def remove_organization_member(
member_id: str,
db: AsyncSession = Depends(get_db)
):
"""
从组织中移除成员
会自动更新组织的成员计数
"""
result = await db.execute(
select(OrganizationMember).where(OrganizationMember.id == member_id)
)
db_member = result.scalar_one_or_none()
if not db_member:
raise HTTPException(status_code=404, detail="成员记录不存在")
# 更新组织成员计数
org_result = await db.execute(
select(Organization).where(Organization.id == db_member.organization_id)
)
org = org_result.scalar_one()
org.member_count = max(0, org.member_count - 1)
await db.delete(db_member)
await db.commit()
logger.info(f"移除成员成功:{member_id}")
return {"message": "成员移除成功", "id": member_id}
+657
View File
@@ -0,0 +1,657 @@
"""大纲管理API"""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, delete
from typing import List
import json
from app.database import get_db
from app.models.outline import Outline
from app.models.project import Project
from app.models.chapter import Chapter
from app.models.character import Character
from app.models.generation_history import GenerationHistory
from app.schemas.outline import (
OutlineCreate,
OutlineUpdate,
OutlineResponse,
OutlineListResponse,
OutlineGenerateRequest,
OutlineReorderRequest
)
from app.services.ai_service import ai_service
from app.services.prompt_service import prompt_service
from app.logger import get_logger
router = APIRouter(prefix="/outlines", tags=["大纲管理"])
logger = get_logger(__name__)
@router.post("", response_model=OutlineResponse, summary="创建大纲")
async def create_outline(
outline: OutlineCreate,
db: AsyncSession = Depends(get_db)
):
"""创建新的章节大纲,同时创建对应的章节记录"""
# 验证项目是否存在
result = await db.execute(
select(Project).where(Project.id == outline.project_id)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 创建大纲
db_outline = Outline(**outline.model_dump())
db.add(db_outline)
# 同步创建对应的章节记录
chapter = Chapter(
project_id=outline.project_id,
chapter_number=outline.order_index,
title=outline.title,
summary=outline.content[:500] if len(outline.content) > 500 else outline.content,
status="draft"
)
db.add(chapter)
await db.commit()
await db.refresh(db_outline)
return db_outline
@router.get("", response_model=OutlineListResponse, summary="获取大纲列表")
async def get_outlines(
project_id: str,
db: AsyncSession = Depends(get_db)
):
"""获取指定项目的所有大纲"""
# 获取总数
count_result = await db.execute(
select(func.count(Outline.id)).where(Outline.project_id == project_id)
)
total = count_result.scalar_one()
# 获取大纲列表
result = await db.execute(
select(Outline)
.where(Outline.project_id == project_id)
.order_by(Outline.order_index)
)
outlines = result.scalars().all()
return OutlineListResponse(total=total, items=outlines)
@router.get("/project/{project_id}", response_model=OutlineListResponse, summary="获取项目的所有大纲")
async def get_project_outlines(
project_id: str,
db: AsyncSession = Depends(get_db)
):
"""获取指定项目的所有大纲(路径参数版本)"""
# 获取总数
count_result = await db.execute(
select(func.count(Outline.id)).where(Outline.project_id == project_id)
)
total = count_result.scalar_one()
# 获取大纲列表
result = await db.execute(
select(Outline)
.where(Outline.project_id == project_id)
.order_by(Outline.order_index)
)
outlines = result.scalars().all()
return OutlineListResponse(total=total, items=outlines)
@router.get("/{outline_id}", response_model=OutlineResponse, summary="获取大纲详情")
async def get_outline(
outline_id: str,
db: AsyncSession = Depends(get_db)
):
"""根据ID获取大纲详情"""
result = await db.execute(
select(Outline).where(Outline.id == outline_id)
)
outline = result.scalar_one_or_none()
if not outline:
raise HTTPException(status_code=404, detail="大纲不存在")
return outline
@router.put("/{outline_id}", response_model=OutlineResponse, summary="更新大纲")
async def update_outline(
outline_id: str,
outline_update: OutlineUpdate,
db: AsyncSession = Depends(get_db)
):
"""更新大纲信息,同步更新对应章节和structure字段"""
result = await db.execute(
select(Outline).where(Outline.id == outline_id)
)
outline = result.scalar_one_or_none()
if not outline:
raise HTTPException(status_code=404, detail="大纲不存在")
# 更新字段
update_data = outline_update.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(outline, field, value)
# 如果修改了content或title,同步更新structure字段
if 'content' in update_data or 'title' in update_data:
try:
# 尝试解析现有的structure
if outline.structure:
structure_data = json.loads(outline.structure)
else:
structure_data = {}
# 更新structure中的对应字段
if 'title' in update_data:
structure_data['title'] = outline.title
if 'content' in update_data:
structure_data['summary'] = outline.content
structure_data['content'] = outline.content
# 保存更新后的structure
outline.structure = json.dumps(structure_data, ensure_ascii=False)
logger.info(f"同步更新大纲 {outline_id} 的structure字段")
except json.JSONDecodeError:
logger.warning(f"大纲 {outline_id} 的structure字段格式错误,跳过更新")
# 同步更新对应的章节标题和摘要
if 'title' in update_data or 'content' in update_data:
chapter_result = await db.execute(
select(Chapter).where(
Chapter.project_id == outline.project_id,
Chapter.chapter_number == outline.order_index
)
)
chapter = chapter_result.scalar_one_or_none()
if chapter:
if 'title' in update_data:
chapter.title = outline.title
if 'content' in update_data:
# 更新章节摘要(取content前500字符)
chapter.summary = outline.content[:500] if len(outline.content) > 500 else outline.content
logger.info(f"同步更新章节 {chapter.id} 的标题和摘要")
else:
logger.warning(f"未找到对应的章节记录 (order_index={outline.order_index})")
await db.commit()
await db.refresh(outline)
return outline
@router.delete("/{outline_id}", summary="删除大纲")
async def delete_outline(
outline_id: str,
db: AsyncSession = Depends(get_db)
):
"""删除大纲,同步删除章节,并重新排序后续项"""
result = await db.execute(
select(Outline).where(Outline.id == outline_id)
)
outline = result.scalar_one_or_none()
if not outline:
raise HTTPException(status_code=404, detail="大纲不存在")
project_id = outline.project_id
deleted_order = outline.order_index
# 删除对应的章节
await db.execute(
delete(Chapter).where(
Chapter.project_id == project_id,
Chapter.chapter_number == deleted_order
)
)
# 删除大纲
await db.delete(outline)
# 重新排序后续的大纲和章节(序号-1)
result = await db.execute(
select(Outline).where(
Outline.project_id == project_id,
Outline.order_index > deleted_order
)
)
subsequent_outlines = result.scalars().all()
for o in subsequent_outlines:
old_order = o.order_index
o.order_index -= 1
# 同步更新对应的章节
chapter_result = await db.execute(
select(Chapter).where(
Chapter.project_id == project_id,
Chapter.chapter_number == old_order
)
)
chapter = chapter_result.scalar_one_or_none()
if chapter:
chapter.chapter_number = old_order - 1
await db.commit()
return {"message": "大纲删除成功"}
@router.post("/reorder", summary="批量重排序大纲")
async def reorder_outlines(
request: OutlineReorderRequest,
db: AsyncSession = Depends(get_db)
):
"""
批量调整大纲顺序,同步更新章节序号
策略:先收集所有变更,最后一次性提交,避免临时冲突
"""
try:
# 第一步:收集所有大纲和对应的章节
outline_chapter_map = {} # {outline_id: (outline, chapter, old_order, new_order)}
for item in request.orders:
outline_id = item.id
new_order = item.order_index
# 获取大纲
result = await db.execute(
select(Outline).where(Outline.id == outline_id)
)
outline = result.scalar_one_or_none()
if not outline:
logger.warning(f"大纲 {outline_id} 不存在,跳过")
continue
old_order = outline.order_index
# 获取对应的章节(通过旧的chapter_number匹配)
chapter_result = await db.execute(
select(Chapter).where(
Chapter.project_id == outline.project_id,
Chapter.chapter_number == old_order
)
)
chapter = chapter_result.first()
chapter_obj = chapter[0] if chapter else None
outline_chapter_map[outline_id] = (outline, chapter_obj, old_order, new_order)
# 第二步:批量更新所有大纲和章节
updated_outlines = 0
updated_chapters = 0
for outline_id, (outline, chapter, old_order, new_order) in outline_chapter_map.items():
# 更新大纲
outline.order_index = new_order
updated_outlines += 1
# 更新章节
if chapter:
chapter.chapter_number = new_order
chapter.title = outline.title # 同步更新标题
updated_chapters += 1
else:
logger.warning(f"章节 {old_order} 不存在,跳过")
# 第三步:一次性提交所有更改
await db.commit()
logger.info(f"重排序成功:更新了 {updated_outlines} 个大纲,{updated_chapters} 个章节")
return {
"message": "重排序成功",
"updated_outlines": updated_outlines,
"updated_chapters": updated_chapters
}
except Exception as e:
await db.rollback()
logger.error(f"重排序失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"重排序失败: {str(e)}")
@router.post("/generate", response_model=OutlineListResponse, summary="AI生成/续写大纲")
async def generate_outline(
request: OutlineGenerateRequest,
db: AsyncSession = Depends(get_db)
):
"""
使用AI生成或续写小说大纲 - 智能模式
支持三种模式:
- auto: 自动判断(无大纲→新建,有大纲→续写)
- new: 强制全新生成
- continue: 强制续写模式
"""
# 验证项目是否存在
result = await db.execute(
select(Project).where(Project.id == request.project_id)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
try:
# 获取现有大纲(强制从数据库获取最新数据,包括用户手动修改的内容)
existing_result = await db.execute(
select(Outline)
.where(Outline.project_id == request.project_id)
.order_by(Outline.order_index)
.execution_options(populate_existing=True)
)
existing_outlines = existing_result.scalars().all()
# 判断实际执行模式
actual_mode = request.mode
if actual_mode == "auto":
actual_mode = "continue" if existing_outlines else "new"
logger.info(f"自动判断模式:{'续写' if existing_outlines else '新建'}")
# 模式:全新生成
if actual_mode == "new":
return await _generate_new_outline(
request, project, db
)
# 模式:续写
elif actual_mode == "continue":
if not existing_outlines:
raise HTTPException(
status_code=400,
detail="续写模式需要已有大纲,当前项目没有大纲"
)
return await _continue_outline(
request, project, existing_outlines, db
)
else:
raise HTTPException(
status_code=400,
detail=f"不支持的模式: {request.mode}"
)
except HTTPException:
raise
except Exception as e:
logger.error(f"生成大纲失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"生成大纲失败: {str(e)}")
async def _generate_new_outline(
request: OutlineGenerateRequest,
project: Project,
db: AsyncSession
) -> OutlineListResponse:
"""全新生成大纲"""
logger.info(f"全新生成大纲 - 项目: {project.id}, keep_existing: {request.keep_existing}")
# 获取角色信息
characters_result = await db.execute(
select(Character).where(Character.project_id == project.id)
)
characters = characters_result.scalars().all()
characters_info = "\n".join([
f"- {char.name} ({'组织' if char.is_organization else '角色'}, {char.role_type}): "
f"{char.personality[:100] if char.personality else '暂无描述'}"
for char in characters
])
# 使用完整提示词
prompt = prompt_service.get_complete_outline_prompt(
title=project.title,
theme=request.theme or project.theme or "未设定",
genre=request.genre or project.genre or "通用",
chapter_count=request.chapter_count,
narrative_perspective=request.narrative_perspective,
target_words=request.target_words,
time_period=project.world_time_period or "未设定",
location=project.world_location or "未设定",
atmosphere=project.world_atmosphere or "未设定",
rules=project.world_rules or "未设定",
characters_info=characters_info or "暂无角色信息",
requirements=request.requirements or ""
)
# 调用AI
ai_response = await ai_service.generate_text(
prompt=prompt,
provider=request.provider,
model=request.model
)
# 解析响应
outline_data = _parse_ai_response(ai_response)
# 全新生成模式:必须删除旧大纲和章节
# 注意:这是"new"模式的核心逻辑,应该始终删除旧数据
logger.info(f"删除项目 {project.id} 的旧大纲和章节")
await db.execute(
delete(Outline).where(Outline.project_id == project.id)
)
await db.execute(
delete(Chapter).where(Chapter.project_id == project.id)
)
# 保存新大纲
outlines = await _save_outlines(
project.id, outline_data, db, start_index=1
)
# 记录历史
history = GenerationHistory(
project_id=project.id,
prompt=prompt,
generated_content=ai_response,
model=request.model or "default"
)
db.add(history)
await db.commit()
for outline in outlines:
await db.refresh(outline)
logger.info(f"全新生成完成 - {len(outlines)}")
return OutlineListResponse(total=len(outlines), items=outlines)
async def _continue_outline(
request: OutlineGenerateRequest,
project: Project,
existing_outlines: List[Outline],
db: AsyncSession
) -> OutlineListResponse:
"""续写大纲"""
logger.info(f"续写大纲 - 项目: {project.id}, 已有: {len(existing_outlines)}")
# 分析已有大纲
current_chapter_count = len(existing_outlines)
last_chapter_number = existing_outlines[-1].order_index
# 获取最近2章的剧情
recent_outlines = existing_outlines[-2:] if len(existing_outlines) >= 2 else existing_outlines
recent_plot = "\n".join([
f"{o.order_index}章《{o.title}》: {o.content}"
for o in recent_outlines
])
# logger.debug(f"最近三章内容:{recent_plot}")
# 全部章节概览
all_chapters_brief = "\n".join([
f"{o.order_index}章: {o.title}"
for o in existing_outlines
])
# 获取角色信息
characters_result = await db.execute(
select(Character).where(Character.project_id == project.id)
)
characters = characters_result.scalars().all()
characters_info = "\n".join([
f"- {char.name} ({'组织' if char.is_organization else '角色'}, {char.role_type}): "
f"{char.personality[:100] if char.personality else '暂无描述'}"
for char in characters
])
# 情节阶段指导
stage_instructions = {
"development": "继续展开情节,深化角色关系,推进主线冲突",
"climax": "进入故事高潮,矛盾激化,关键冲突爆发",
"ending": "解决主要冲突,收束伏笔,给出结局"
}
stage_instruction = stage_instructions.get(request.plot_stage, "")
# 使用标准续写提示词模板
prompt = prompt_service.get_outline_continue_prompt(
title=project.title,
theme=request.theme or project.theme or "未设定",
genre=request.genre or project.genre or "通用",
narrative_perspective=request.narrative_perspective,
chapter_count=request.chapter_count,
time_period=project.world_time_period or "未设定",
location=project.world_location or "未设定",
atmosphere=project.world_atmosphere or "未设定",
rules=project.world_rules or "未设定",
characters_info=characters_info or "暂无角色信息",
current_chapter_count=current_chapter_count,
all_chapters_brief=all_chapters_brief,
recent_plot=recent_plot,
plot_stage_instruction=stage_instruction,
start_chapter=last_chapter_number + 1,
story_direction=request.story_direction or "自然延续",
requirements=request.requirements or ""
)
# 调用AI
ai_response = await ai_service.generate_text(
prompt=prompt,
provider=request.provider,
model=request.model
)
# 解析响应
outline_data = _parse_ai_response(ai_response)
# 保存续写的大纲
new_outlines = await _save_outlines(
project.id, outline_data, db, start_index=last_chapter_number + 1
)
# 记录历史
history = GenerationHistory(
project_id=project.id,
prompt=prompt,
generated_content=ai_response,
model=request.model or "default"
)
db.add(history)
await db.commit()
for outline in new_outlines:
await db.refresh(outline)
# 返回所有大纲(包括旧的和新的)
all_result = await db.execute(
select(Outline)
.where(Outline.project_id == project.id)
.order_by(Outline.order_index)
)
all_outlines = all_result.scalars().all()
logger.info(f"续写完成 - 新增 {len(new_outlines)} 章,总计 {len(all_outlines)}")
return OutlineListResponse(total=len(all_outlines), items=all_outlines)
def _parse_ai_response(ai_response: str) -> list:
"""解析AI响应为章节数据列表"""
try:
# 清理响应文本
cleaned_text = ai_response.strip()
if cleaned_text.startswith('```json'):
cleaned_text = cleaned_text[7:]
if cleaned_text.startswith('```'):
cleaned_text = cleaned_text[3:]
if cleaned_text.endswith('```'):
cleaned_text = cleaned_text[:-3]
cleaned_text = cleaned_text.strip()
outline_data = json.loads(cleaned_text)
# 确保是列表格式
if not isinstance(outline_data, list):
# 如果是对象,尝试提取chapters字段
if isinstance(outline_data, dict):
outline_data = outline_data.get("chapters", [outline_data])
else:
outline_data = [outline_data]
return outline_data
except json.JSONDecodeError as e:
logger.error(f"AI响应解析失败: {e}")
# 返回一个包含原始内容的章节
return [{
"title": "AI生成的大纲",
"content": ai_response[:1000],
"summary": ai_response[:1000]
}]
async def _save_outlines(
project_id: str,
outline_data: list,
db: AsyncSession,
start_index: int = 1
) -> List[Outline]:
"""保存大纲到数据库"""
outlines = []
for idx, chapter_data in enumerate(outline_data):
order_idx = chapter_data.get("chapter_number", start_index + idx)
title = chapter_data.get("title", f"{order_idx}")
# 优先使用summary,其次content
content = chapter_data.get("summary") or chapter_data.get("content", "")
# 如果有额外信息,添加到内容中
if "key_events" in chapter_data:
content += f"\n\n关键事件:" + "".join(chapter_data["key_events"])
if "characters_involved" in chapter_data:
content += f"\n涉及角色:" + "".join(chapter_data["characters_involved"])
# 创建大纲
outline = Outline(
project_id=project_id,
title=title,
content=content,
structure=json.dumps(chapter_data, ensure_ascii=False),
order_index=order_idx
)
db.add(outline)
outlines.append(outline)
# 同步创建章节记录
chapter = Chapter(
project_id=project_id,
chapter_number=order_idx,
title=title,
summary=content[:500] if len(content) > 500 else content,
status="draft"
)
db.add(chapter)
return outlines
+124
View File
@@ -0,0 +1,124 @@
"""AI去味API - 核心特色功能"""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.models.generation_history import GenerationHistory
from app.schemas.polish import PolishRequest, PolishResponse
from app.services.ai_service import ai_service
from app.services.prompt_service import prompt_service
from app.logger import get_logger
router = APIRouter(prefix="/polish", tags=["AI去味"])
logger = get_logger(__name__)
@router.post("", response_model=PolishResponse, summary="AI去味")
async def polish_text(
request: PolishRequest,
db: AsyncSession = Depends(get_db)
):
"""
AI去味 - 将AI生成的文本改写得更像人类作家的手笔
核心功能:
- 去除AI痕迹(工整排比、重复修辞、机械总结)
- 增加人性化(口语化、不完美细节、真实情感)
- 优化叙事(自然节奏、简单词汇、松弛感)
- 让对话更生活化
这是本项目的核心特色功能!
"""
try:
# 构建AI去味提示词
prompt = prompt_service.get_denoising_prompt(
original_text=request.original_text
)
logger.info(f"开始AI去味处理,原文长度: {len(request.original_text)}")
# 调用AI进行去味处理
polished_text = await ai_service.generate_text(
prompt=prompt,
provider=request.provider,
model=request.model,
temperature=request.temperature,
max_tokens=len(request.original_text) * 2 # 预留足够token
)
# 计算字数
word_count_before = len(request.original_text)
word_count_after = len(polished_text)
logger.info(f"AI去味完成,处理后长度: {word_count_after}")
# 如果提供了项目ID,记录到历史
if request.project_id:
history = GenerationHistory(
project_id=request.project_id,
generation_type="polish",
prompt=f"原文: {request.original_text[:100]}...",
result=polished_text,
provider=request.provider or "default",
model=request.model or "default"
)
db.add(history)
await db.commit()
return PolishResponse(
original_text=request.original_text,
polished_text=polished_text,
word_count_before=word_count_before,
word_count_after=word_count_after
)
except Exception as e:
logger.error(f"AI去味失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"AI去味失败: {str(e)}")
@router.post("/batch", summary="批量AI去味")
async def polish_batch(
texts: list[str],
project_id: int = None,
provider: str = None,
model: str = None,
db: AsyncSession = Depends(get_db)
):
"""
批量处理多个文本的AI去味
适用于一次性处理多个章节或段落
"""
try:
results = []
for idx, text in enumerate(texts):
logger.info(f"处理第 {idx+1}/{len(texts)} 个文本")
prompt = prompt_service.get_denoising_prompt(original_text=text)
polished_text = await ai_service.generate_text(
prompt=prompt,
provider=provider,
model=model
)
results.append({
"index": idx,
"original": text,
"polished": polished_text,
"word_count_before": len(text),
"word_count_after": len(polished_text)
})
logger.info(f"批量AI去味完成,共处理 {len(results)} 个文本")
return {
"total": len(results),
"results": results
}
except Exception as e:
logger.error(f"批量AI去味失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"批量AI去味失败: {str(e)}")
+414
View File
@@ -0,0 +1,414 @@
"""项目管理API"""
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import Response
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, delete
from typing import List
from app.database import get_db
from app.models.project import Project
from app.models.character import Character
from app.models.outline import Outline
from app.models.chapter import Chapter
from app.models.generation_history import GenerationHistory
from app.models.relationship import CharacterRelationship, Organization, OrganizationMember
from app.schemas.project import (
ProjectCreate,
ProjectUpdate,
ProjectResponse,
ProjectListResponse
)
from app.logger import get_logger
from app.utils.data_consistency import (
run_full_data_consistency_check,
fix_missing_organization_records,
fix_organization_member_counts
)
logger = get_logger(__name__)
router = APIRouter(prefix="/projects", tags=["项目管理"])
@router.post("", response_model=ProjectResponse, summary="创建项目")
async def create_project(
project: ProjectCreate,
db: AsyncSession = Depends(get_db)
):
try:
logger.info(f"创建新项目: {project.title}")
db_project = Project(**project.model_dump())
db.add(db_project)
await db.commit()
await db.refresh(db_project)
logger.info(f"项目创建成功: {db_project.id}")
return db_project
except Exception as e:
logger.error(f"创建项目失败: {str(e)}", exc_info=True)
raise
@router.get("", response_model=ProjectListResponse, summary="获取项目列表")
async def get_projects(
skip: int = 0,
limit: int = 100,
db: AsyncSession = Depends(get_db)
):
"""获取所有项目列表"""
try:
logger.debug(f"获取项目列表: skip={skip}, limit={limit}")
count_result = await db.execute(select(func.count(Project.id)))
total = count_result.scalar_one()
result = await db.execute(
select(Project)
.order_by(Project.updated_at.desc())
.offset(skip)
.limit(limit)
)
projects = result.scalars().all()
logger.info(f"获取项目列表成功: 共{total}个项目")
return ProjectListResponse(total=total, items=projects)
except Exception as e:
logger.error(f"获取项目列表失败: {str(e)}", exc_info=True)
raise
@router.get("/{project_id}", response_model=ProjectResponse, summary="获取项目详情")
async def get_project(
project_id: str,
db: AsyncSession = Depends(get_db)
):
try:
logger.debug(f"获取项目详情: {project_id}")
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目不存在: {project_id}")
raise HTTPException(status_code=404, detail="项目不存在")
logger.info(f"获取项目详情成功: {project.title}")
return project
except HTTPException:
raise
except Exception as e:
logger.error(f"获取项目详情失败: {str(e)}", exc_info=True)
raise
@router.put("/{project_id}", response_model=ProjectResponse, summary="更新项目")
async def update_project(
project_id: str,
project_update: ProjectUpdate,
db: AsyncSession = Depends(get_db)
):
try:
logger.info(f"更新项目: {project_id}")
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目不存在: {project_id}")
raise HTTPException(status_code=404, detail="项目不存在")
update_data = project_update.model_dump(exclude_unset=True)
logger.debug(f"更新字段: {list(update_data.keys())}")
for field, value in update_data.items():
setattr(project, field, value)
await db.commit()
await db.refresh(project)
logger.info(f"项目更新成功: {project.title}")
return project
except HTTPException:
raise
except Exception as e:
logger.error(f"更新项目失败: {str(e)}", exc_info=True)
raise
@router.delete("/{project_id}", summary="删除项目")
async def delete_project(
project_id: str,
db: AsyncSession = Depends(get_db)
):
try:
logger.info(f"删除项目: {project_id}")
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目不存在: {project_id}")
raise HTTPException(status_code=404, detail="项目不存在")
project_title = project.title
relationships_result = await db.execute(
delete(CharacterRelationship).where(CharacterRelationship.project_id == project_id)
)
logger.debug(f"删除角色关系数: {relationships_result.rowcount}")
orgs_result = await db.execute(
select(Organization).where(Organization.project_id == project_id)
)
orgs = orgs_result.scalars().all()
org_member_count = 0
for org in orgs:
members_result = await db.execute(
delete(OrganizationMember).where(OrganizationMember.organization_id == org.id)
)
org_member_count += members_result.rowcount
logger.debug(f"删除组织成员数: {org_member_count}")
organizations_result = await db.execute(
delete(Organization).where(Organization.project_id == project_id)
)
logger.debug(f"删除组织数: {organizations_result.rowcount}")
history_result = await db.execute(
delete(GenerationHistory).where(GenerationHistory.project_id == project_id)
)
logger.debug(f"删除生成历史数: {history_result.rowcount}")
chapters_result = await db.execute(
delete(Chapter).where(Chapter.project_id == project_id)
)
logger.debug(f"删除章节数: {chapters_result.rowcount}")
outlines_result = await db.execute(
delete(Outline).where(Outline.project_id == project_id)
)
logger.debug(f"删除大纲数: {outlines_result.rowcount}")
characters_result = await db.execute(
delete(Character).where(Character.project_id == project_id)
)
logger.debug(f"删除角色数: {characters_result.rowcount}")
await db.delete(project)
await db.commit()
logger.info(f"项目删除成功: {project_title}")
return {"message": "项目及所有关联数据删除成功"}
except HTTPException:
raise
except Exception as e:
logger.error(f"删除项目失败: {str(e)}", exc_info=True)
raise
@router.get("/{project_id}/export", summary="导出项目章节为TXT")
async def export_project_chapters(
project_id: str,
db: AsyncSession = Depends(get_db)
):
"""
导出项目的所有章节内容为TXT文本文件
按章节顺序组织,包含项目基本信息
"""
try:
logger.info(f"开始导出项目: {project_id}")
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目不存在: {project_id}")
raise HTTPException(status_code=404, detail="项目不存在")
chapters_result = await db.execute(
select(Chapter)
.where(Chapter.project_id == project_id)
.order_by(Chapter.chapter_number)
)
chapters = chapters_result.scalars().all()
if not chapters:
logger.warning(f"项目没有章节: {project_id}")
raise HTTPException(status_code=404, detail="项目没有任何章节")
txt_content = []
txt_content.append("=" * 80)
txt_content.append(f"项目标题: {project.title}")
txt_content.append("=" * 80)
if project.description:
txt_content.append(f"\n简介: {project.description}\n")
if project.theme:
txt_content.append(f"主题: {project.theme}")
if project.genre:
txt_content.append(f"类型: {project.genre}")
txt_content.append(f"总章节数: {len(chapters)}")
txt_content.append(f"总字数: {project.current_words}")
txt_content.append("\n" + "=" * 80 + "\n\n")
for chapter in chapters:
txt_content.append(f"{chapter.chapter_number}{chapter.title}")
txt_content.append("-" * 80)
txt_content.append("") # 空行
if chapter.content:
txt_content.append(chapter.content)
else:
txt_content.append("(本章暂无内容)")
txt_content.append("\n\n" + "=" * 80 + "\n\n")
txt_content.append(f"--- 全文完 ---")
txt_content.append(f"\n导出时间: {func.now()}")
final_content = "\n".join(txt_content)
safe_title = "".join(c for c in project.title if c.isalnum() or c in (' ', '-', '_', '', '', ''))
filename = f"{safe_title}.txt"
from urllib.parse import quote
encoded_filename = quote(filename)
logger.info(f"导出成功: {filename}, 共{len(chapters)}章, {len(final_content)}字符")
return Response(
content=final_content.encode('utf-8'),
media_type="text/plain; charset=utf-8",
headers={
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
"Content-Type": "text/plain; charset=utf-8"
}
)
except HTTPException:
raise
except Exception as e:
logger.error(f"导出项目失败: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"导出失败: {str(e)}")
@router.post("/{project_id}/check-consistency", summary="检查数据一致性")
async def check_project_consistency(
project_id: str,
auto_fix: bool = True,
db: AsyncSession = Depends(get_db)
):
"""
检查并修复项目的数据一致性问题
Args:
project_id: 项目ID
auto_fix: 是否自动修复问题(默认True)
返回检查报告,包含:
- organization_records: 检查并修复缺失的Organization记录
- member_counts: 检查并修复组织成员计数
- relationships: 验证关系数据完整性
- organization_members: 验证组织成员数据完整性
"""
try:
logger.info(f"开始数据一致性检查: {project_id}, auto_fix={auto_fix}")
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目不存在: {project_id}")
raise HTTPException(status_code=404, detail="项目不存在")
report = await run_full_data_consistency_check(project_id, db, auto_fix)
logger.info(f"数据一致性检查完成: {project_id}")
return report
except HTTPException:
raise
except Exception as e:
logger.error(f"数据一致性检查失败: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"检查失败: {str(e)}")
@router.post("/{project_id}/fix-organizations", summary="修复组织记录")
async def fix_project_organizations(
project_id: str,
db: AsyncSession = Depends(get_db)
):
"""
修复项目中缺失的Organization记录
为所有is_organization=True但没有Organization记录的Character创建记录
"""
try:
logger.info(f"开始修复组织记录: {project_id}")
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目不存在: {project_id}")
raise HTTPException(status_code=404, detail="项目不存在")
fixed_count, total_count = await fix_missing_organization_records(project_id, db)
logger.info(f"组织记录修复完成: {project_id}, 修复{fixed_count}/{total_count}")
return {
"message": "组织记录修复完成",
"fixed": fixed_count,
"total": total_count
}
except HTTPException:
raise
except Exception as e:
logger.error(f"修复组织记录失败: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"修复失败: {str(e)}")
@router.post("/{project_id}/fix-member-counts", summary="修复成员计数")
async def fix_project_member_counts(
project_id: str,
db: AsyncSession = Depends(get_db)
):
"""
修复项目中所有组织的成员计数
从实际成员记录重新计算每个组织的member_count
"""
try:
logger.info(f"开始修复成员计数: {project_id}")
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目不存在: {project_id}")
raise HTTPException(status_code=404, detail="项目不存在")
fixed_count, total_count = await fix_organization_member_counts(project_id, db)
logger.info(f"成员计数修复完成: {project_id}, 修复{fixed_count}/{total_count}")
return {
"message": "成员计数修复完成",
"fixed": fixed_count,
"total": total_count
}
except HTTPException:
raise
except Exception as e:
logger.error(f"修复成员计数失败: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"修复失败: {str(e)}")
+209
View File
@@ -0,0 +1,209 @@
"""关系管理API"""
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, or_, and_
from typing import List, Optional
from app.database import get_db
from app.models.relationship import (
RelationshipType,
CharacterRelationship,
Organization,
OrganizationMember
)
from app.models.character import Character
from app.schemas.relationship import (
RelationshipTypeResponse,
CharacterRelationshipCreate,
CharacterRelationshipUpdate,
CharacterRelationshipResponse,
RelationshipGraphData,
RelationshipGraphNode,
RelationshipGraphLink
)
from app.logger import get_logger
router = APIRouter(prefix="/relationships", tags=["关系管理"])
logger = get_logger(__name__)
@router.get("/types", response_model=List[RelationshipTypeResponse], summary="获取关系类型列表")
async def get_relationship_types(db: AsyncSession = Depends(get_db)):
"""获取所有预定义的关系类型"""
result = await db.execute(select(RelationshipType).order_by(RelationshipType.category, RelationshipType.id))
types = result.scalars().all()
return types
@router.get("/project/{project_id}", response_model=List[CharacterRelationshipResponse], summary="获取项目的所有关系")
async def get_project_relationships(
project_id: str,
character_id: Optional[str] = Query(None, description="筛选特定角色的关系"),
db: AsyncSession = Depends(get_db)
):
"""
获取项目中的所有角色关系
- 如果提供character_id,则只返回与该角色相关的关系(作为发起方或接收方)
- 否则返回项目中的所有关系
"""
query = select(CharacterRelationship).where(
CharacterRelationship.project_id == project_id
)
if character_id:
query = query.where(
or_(
CharacterRelationship.character_from_id == character_id,
CharacterRelationship.character_to_id == character_id
)
)
query = query.order_by(CharacterRelationship.created_at.desc())
result = await db.execute(query)
relationships = result.scalars().all()
logger.info(f"获取项目 {project_id} 的关系列表,共 {len(relationships)}")
return relationships
@router.get("/graph/{project_id}", response_model=RelationshipGraphData, summary="获取关系图谱数据")
async def get_relationship_graph(
project_id: str,
db: AsyncSession = Depends(get_db)
):
"""
获取用于可视化的关系图谱数据
返回格式:
- nodes: 角色节点列表
- links: 关系连线列表
"""
# 获取所有角色(节点)
chars_result = await db.execute(
select(Character).where(Character.project_id == project_id)
)
characters = chars_result.scalars().all()
nodes = [
RelationshipGraphNode(
id=c.id,
name=c.name,
type="organization" if c.is_organization else "character",
role_type=c.role_type,
avatar=c.avatar_url
)
for c in characters
]
# 获取所有关系(边)
rels_result = await db.execute(
select(CharacterRelationship).where(
CharacterRelationship.project_id == project_id
)
)
relationships = rels_result.scalars().all()
links = [
RelationshipGraphLink(
source=r.character_from_id,
target=r.character_to_id,
relationship=r.relationship_name or "未知关系",
intimacy=r.intimacy_level,
status=r.status
)
for r in relationships
]
logger.info(f"获取项目 {project_id} 的关系图谱:{len(nodes)} 个节点,{len(links)} 条关系")
return RelationshipGraphData(nodes=nodes, links=links)
@router.post("/", response_model=CharacterRelationshipResponse, summary="创建角色关系")
async def create_relationship(
relationship: CharacterRelationshipCreate,
db: AsyncSession = Depends(get_db)
):
"""
手动创建角色关系
- 需要提供角色A和角色B的ID
- 可以指定预定义的关系类型或自定义关系名称
- 可以设置亲密度、状态等属性
"""
# 验证角色是否存在
char_from = await db.execute(
select(Character).where(Character.id == relationship.character_from_id)
)
char_to = await db.execute(
select(Character).where(Character.id == relationship.character_to_id)
)
if not char_from.scalar_one_or_none():
raise HTTPException(status_code=404, detail=f"角色AID: {relationship.character_from_id})不存在")
if not char_to.scalar_one_or_none():
raise HTTPException(status_code=404, detail=f"角色BID: {relationship.character_to_id})不存在")
# 创建关系
db_relationship = CharacterRelationship(
**relationship.model_dump(),
source="manual"
)
db.add(db_relationship)
await db.commit()
await db.refresh(db_relationship)
logger.info(f"创建关系成功:{relationship.character_from_id} -> {relationship.character_to_id}")
return db_relationship
@router.put("/{relationship_id}", response_model=CharacterRelationshipResponse, summary="更新关系")
async def update_relationship(
relationship_id: str,
relationship: CharacterRelationshipUpdate,
db: AsyncSession = Depends(get_db)
):
"""更新角色关系的属性(亲密度、状态等)"""
result = await db.execute(
select(CharacterRelationship).where(
CharacterRelationship.id == relationship_id
)
)
db_rel = result.scalar_one_or_none()
if not db_rel:
raise HTTPException(status_code=404, detail="关系不存在")
# 更新字段
update_data = relationship.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(db_rel, field, value)
await db.commit()
await db.refresh(db_rel)
logger.info(f"更新关系成功:{relationship_id}")
return db_rel
@router.delete("/{relationship_id}", summary="删除关系")
async def delete_relationship(
relationship_id: str,
db: AsyncSession = Depends(get_db)
):
"""删除角色关系"""
result = await db.execute(
select(CharacterRelationship).where(
CharacterRelationship.id == relationship_id
)
)
db_rel = result.scalar_one_or_none()
if not db_rel:
raise HTTPException(status_code=404, detail="关系不存在")
await db.delete(db_rel)
await db.commit()
logger.info(f"删除关系成功:{relationship_id}")
return {"message": "关系删除成功", "id": relationship_id}
+125
View File
@@ -0,0 +1,125 @@
"""
用户管理 API
"""
from fastapi import APIRouter, HTTPException, Request, Depends
from pydantic import BaseModel
from typing import List, Optional
from app.user_manager import user_manager, User
router = APIRouter(prefix="/users", tags=["用户管理"])
def require_login(request: Request):
"""依赖:要求用户已登录"""
if not hasattr(request.state, "user") or not request.state.user:
raise HTTPException(status_code=401, detail="需要登录")
return request.state.user
def require_admin(request: Request):
"""依赖:要求用户为管理员"""
user = require_login(request)
if not request.state.is_admin:
raise HTTPException(status_code=403, detail="需要管理员权限")
return user
class SetAdminRequest(BaseModel):
user_id: str
is_admin: bool
@router.get("/current")
async def get_current_user(user: User = Depends(require_login)):
"""获取当前登录用户信息"""
return user.dict()
@router.get("", response_model=List[dict])
async def list_users(admin_user: User = Depends(require_admin)):
"""
获取所有用户列表(仅管理员)
"""
users = await user_manager.get_all_users()
return [user.dict() for user in users]
@router.post("/set-admin")
async def set_admin(
data: SetAdminRequest,
request: Request,
admin_user: User = Depends(require_admin)
):
"""
设置用户的管理员权限(仅管理员)
限制:
- 不能撤销自己的管理员权限
- 至少保留一个管理员
"""
# 检查是否尝试撤销自己的权限
if data.user_id == admin_user.user_id and not data.is_admin:
raise HTTPException(
status_code=400,
detail="不能撤销自己的管理员权限"
)
# 尝试设置管理员权限
success = await user_manager.set_admin(data.user_id, data.is_admin)
if not success:
if not data.is_admin:
raise HTTPException(
status_code=400,
detail="无法撤销管理员权限,至少需要保留一个管理员"
)
else:
raise HTTPException(
status_code=404,
detail="用户不存在"
)
return {
"message": f"{'授予' if data.is_admin else '撤销'}管理员权限",
"user_id": data.user_id,
"is_admin": data.is_admin
}
@router.delete("/{user_id}")
async def delete_user(
user_id: str,
admin_user: User = Depends(require_admin)
):
"""
删除用户(仅管理员)
限制:
- 不能删除管理员用户
"""
success = await user_manager.delete_user(user_id)
if not success:
raise HTTPException(
status_code=400,
detail="无法删除该用户(用户不存在或为管理员)"
)
return {
"message": "用户已删除",
"user_id": user_id
}
@router.get("/{user_id}")
async def get_user(
user_id: str,
admin_user: User = Depends(require_admin)
):
"""获取指定用户信息(仅管理员)"""
user = await user_manager.get_user(user_id)
if not user:
raise HTTPException(status_code=404, detail="用户不存在")
return user.dict()
File diff suppressed because it is too large Load Diff
+90
View File
@@ -0,0 +1,90 @@
"""应用配置管理"""
from pydantic_settings import BaseSettings
from typing import Optional
from pathlib import Path
import logging
# 获取项目根目录(从backend/app/config.py向上两级)
PROJECT_ROOT = Path(__file__).parent.parent
DATA_DIR = PROJECT_ROOT / "data"
DATA_DIR.mkdir(exist_ok=True)
# 配置模块使用标准logging(在logger.py初始化之前)
config_logger = logging.getLogger(__name__)
# 数据库文件路径(绝对路径)
DB_FILE = DATA_DIR / "ai_story.db"
# 生成数据库URL(在类外部生成,确保使用绝对路径)
# 将Windows反斜杠转换为正斜杠,SQLite URL格式要求
DATABASE_URL = f"sqlite+aiosqlite:///{str(DB_FILE.absolute()).replace(chr(92), '/')}"
config_logger.debug(f"数据库文件路径: {DB_FILE}")
config_logger.debug(f"数据库URL: {DATABASE_URL}")
class Settings(BaseSettings):
"""应用配置"""
# 应用配置
app_name: str = "MuMuAINovel"
app_version: str = "1.0.0"
app_host: str = "0.0.0.0"
app_port: int = 8000
debug: bool = True
# 日志配置
log_level: str = "INFO" # DEBUG, INFO, WARNING, ERROR, CRITICAL
log_to_file: bool = True # 是否输出到文件
log_file_path: str = str(PROJECT_ROOT / "logs" / "app.log")
log_max_bytes: int = 10 * 1024 * 1024 # 10MB
log_backup_count: int = 30 # 保留30个备份文件
# CORS配置
cors_origins: list[str] = ["http://localhost:8000", "http://127.0.0.1:8000"]
# 数据库配置 - 使用预先计算好的绝对路径URL
database_url: str = DATABASE_URL
# AI服务配置
openai_api_key: Optional[str] = None
openai_base_url: Optional[str] = None
gemini_api_key: Optional[str] = None
gemini_base_url: Optional[str] = None
anthropic_api_key: Optional[str] = None
anthropic_base_url: Optional[str] = None
default_ai_provider: str = "openai"
default_model: str = "gpt-4"
default_temperature: float = 0.7
default_max_tokens: int = 2000
# LinuxDO OAuth2 配置
LINUXDO_CLIENT_ID: Optional[str] = None
LINUXDO_CLIENT_SECRET: Optional[str] = None
# 回调地址:Docker部署时必须使用实际域名或服务器IP,不能使用localhost
# 本地开发: http://localhost:8000/api/auth/callback
# 生产环境: https://your-domain.com/api/auth/callback 或 http://your-ip:8000/api/auth/callback
LINUXDO_REDIRECT_URI: Optional[str] = None
# 前端URL配置(用于OAuth回调后重定向)
# 本地开发: http://localhost:8000
# 生产环境: https://your-domain.com 或 http://your-ip:8000
FRONTEND_URL: str = "http://localhost:8000"
# 初始管理员配置(LinuxDO user_id
INITIAL_ADMIN_LINUXDO_ID: Optional[str] = None
# 本地账户登录配置
LOCAL_AUTH_ENABLED: bool = True # 是否启用本地账户登录
LOCAL_AUTH_USERNAME: Optional[str] = None # 本地登录用户名
LOCAL_AUTH_PASSWORD: Optional[str] = None # 本地登录密码
LOCAL_AUTH_DISPLAY_NAME: str = "本地用户" # 本地用户显示名称
class Config:
env_file = ".env"
case_sensitive = False
# 创建全局配置实例
settings = Settings()
config_logger.info(f"配置加载完成: {settings.app_name} v{settings.app_version}")
config_logger.debug(f"调试模式: {settings.debug}")
config_logger.debug(f"AI提供商: {settings.default_ai_provider}")
+261
View File
@@ -0,0 +1,261 @@
"""数据库连接和会话管理 - 支持多用户数据隔离"""
import asyncio
from typing import Dict, Any
from datetime import datetime
from sqlalchemy import select, text
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import declarative_base
from sqlalchemy.pool import StaticPool
from fastapi import Request, HTTPException
from app.config import settings
from app.logger import get_logger
logger = get_logger(__name__)
# 创建基类
Base = declarative_base()
# 引擎缓存:每个用户一个引擎
_engine_cache: Dict[str, Any] = {}
# 锁管理:用于保护引擎创建过程
_engine_locks: Dict[str, asyncio.Lock] = {}
_cache_lock = asyncio.Lock()
# 会话统计(用于监控连接泄漏)
_session_stats = {
"created": 0,
"closed": 0,
"active": 0,
"errors": 0,
"generator_exits": 0,
"last_check": None
}
async def get_engine(user_id: str):
"""获取或创建用户专属的数据库引擎(线程安全)
Args:
user_id: 用户ID
Returns:
用户专属的异步引擎
"""
if user_id in _engine_cache:
return _engine_cache[user_id]
async with _cache_lock:
if user_id not in _engine_locks:
_engine_locks[user_id] = asyncio.Lock()
user_lock = _engine_locks[user_id]
async with user_lock:
if user_id not in _engine_cache:
db_url = f"sqlite+aiosqlite:///data/ai_story_user_{user_id}.db"
engine = create_async_engine(
db_url,
echo=False,
future=True,
poolclass=StaticPool,
pool_pre_ping=True,
pool_recycle=3600,
connect_args={
"timeout": 30,
"check_same_thread": False
}
)
try:
async with engine.begin() as conn:
await conn.execute(text("PRAGMA journal_mode=WAL"))
await conn.execute(text("PRAGMA synchronous=NORMAL"))
await conn.execute(text("PRAGMA cache_size=-64000"))
await conn.execute(text("PRAGMA temp_store=MEMORY"))
await conn.execute(text("PRAGMA busy_timeout=5000"))
logger.info(f"✅ 用户 {user_id} 的数据库已优化(WAL模式 + 64MB缓存)")
except Exception as e:
logger.warning(f"⚠️ 用户 {user_id} 数据库优化失败: {str(e)}")
_engine_cache[user_id] = engine
logger.info(f"为用户 {user_id} 创建数据库引擎")
return _engine_cache[user_id]
async def get_db(request: Request):
"""获取数据库会话的依赖函数
从 request.state.user_id 获取用户ID,然后返回该用户的数据库会话
"""
user_id = getattr(request.state, "user_id", None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录或用户ID缺失")
engine = await get_engine(user_id)
AsyncSessionLocal = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False
)
session = AsyncSessionLocal()
session_id = id(session)
global _session_stats
_session_stats["created"] += 1
_session_stats["active"] += 1
logger.debug(f"📊 会话创建 [User:{user_id}][ID:{session_id}] - 活跃:{_session_stats['active']}, 总创建:{_session_stats['created']}, 总关闭:{_session_stats['closed']}")
try:
yield session
if session.in_transaction():
await session.rollback()
except GeneratorExit:
_session_stats["generator_exits"] += 1
logger.warning(f"⚠️ GeneratorExit [User:{user_id}][ID:{session_id}] - SSE连接断开(总计:{_session_stats['generator_exits']}次)")
try:
if session.in_transaction():
await session.rollback()
logger.info(f"✅ 事务已回滚 [User:{user_id}][ID:{session_id}]GeneratorExit")
except Exception as rollback_error:
_session_stats["errors"] += 1
logger.error(f"❌ GeneratorExit回滚失败 [User:{user_id}][ID:{session_id}]: {str(rollback_error)}")
except Exception as e:
_session_stats["errors"] += 1
logger.error(f"❌ 会话异常 [User:{user_id}][ID:{session_id}]: {str(e)}")
try:
if session.in_transaction():
await session.rollback()
logger.info(f"✅ 事务已回滚 [User:{user_id}][ID:{session_id}](异常)")
except Exception as rollback_error:
logger.error(f"❌ 异常回滚失败 [User:{user_id}][ID:{session_id}]: {str(rollback_error)}")
raise
finally:
try:
if session.in_transaction():
await session.rollback()
logger.warning(f"⚠️ finally中发现未提交事务 [User:{user_id}][ID:{session_id}],已回滚")
await session.close()
_session_stats["closed"] += 1
_session_stats["active"] -= 1
_session_stats["last_check"] = datetime.now().isoformat()
logger.debug(f"📊 会话关闭 [User:{user_id}][ID:{session_id}] - 活跃:{_session_stats['active']}, 总创建:{_session_stats['created']}, 总关闭:{_session_stats['closed']}, 错误:{_session_stats['errors']}")
if _session_stats["active"] > 10:
logger.warning(f"🚨 活跃会话数过多: {_session_stats['active']},可能存在连接泄漏!")
elif _session_stats["active"] < 0:
logger.error(f"🚨 活跃会话数异常: {_session_stats['active']},统计可能不准确!")
except Exception as e:
_session_stats["errors"] += 1
logger.error(f"❌ 关闭会话时出错 [User:{user_id}][ID:{session_id}]: {str(e)}", exc_info=True)
try:
await session.close()
except:
pass
async def _init_relationship_types(user_id: str):
"""为指定用户初始化预置的关系类型数据
Args:
user_id: 用户ID
"""
from app.models.relationship import RelationshipType
relationship_types = [
{"name": "父亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👨"},
{"name": "母亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👩"},
{"name": "兄弟", "category": "family", "reverse_name": "兄弟", "intimacy_range": "high", "icon": "👬"},
{"name": "姐妹", "category": "family", "reverse_name": "姐妹", "intimacy_range": "high", "icon": "👭"},
{"name": "子女", "category": "family", "reverse_name": "父母", "intimacy_range": "high", "icon": "👶"},
{"name": "配偶", "category": "family", "reverse_name": "配偶", "intimacy_range": "high", "icon": "💑"},
{"name": "恋人", "category": "family", "reverse_name": "恋人", "intimacy_range": "high", "icon": "💕"},
{"name": "师父", "category": "social", "reverse_name": "徒弟", "intimacy_range": "high", "icon": "🎓"},
{"name": "徒弟", "category": "social", "reverse_name": "师父", "intimacy_range": "high", "icon": "📚"},
{"name": "朋友", "category": "social", "reverse_name": "朋友", "intimacy_range": "medium", "icon": "🤝"},
{"name": "同学", "category": "social", "reverse_name": "同学", "intimacy_range": "medium", "icon": "🎒"},
{"name": "邻居", "category": "social", "reverse_name": "邻居", "intimacy_range": "low", "icon": "🏘️"},
{"name": "知己", "category": "social", "reverse_name": "知己", "intimacy_range": "high", "icon": "💙"},
{"name": "上司", "category": "professional", "reverse_name": "下属", "intimacy_range": "low", "icon": "👔"},
{"name": "下属", "category": "professional", "reverse_name": "上司", "intimacy_range": "low", "icon": "💼"},
{"name": "同事", "category": "professional", "reverse_name": "同事", "intimacy_range": "medium", "icon": "🤵"},
{"name": "合作伙伴", "category": "professional", "reverse_name": "合作伙伴", "intimacy_range": "medium", "icon": "🤜🤛"},
{"name": "敌人", "category": "hostile", "reverse_name": "敌人", "intimacy_range": "low", "icon": "⚔️"},
{"name": "仇人", "category": "hostile", "reverse_name": "仇人", "intimacy_range": "low", "icon": "💢"},
{"name": "竞争对手", "category": "hostile", "reverse_name": "竞争对手", "intimacy_range": "low", "icon": "🎯"},
{"name": "宿敌", "category": "hostile", "reverse_name": "宿敌", "intimacy_range": "low", "icon": ""},
]
try:
engine = await get_engine(user_id)
AsyncSessionLocal = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False
)
async with AsyncSessionLocal() as session:
result = await session.execute(select(RelationshipType))
existing = result.scalars().first()
if existing:
logger.info(f"用户 {user_id} 的关系类型数据已存在,跳过初始化")
return
logger.info(f"开始为用户 {user_id} 插入关系类型数据...")
for rt_data in relationship_types:
relationship_type = RelationshipType(**rt_data)
session.add(relationship_type)
await session.commit()
logger.info(f"成功为用户 {user_id} 插入 {len(relationship_types)} 条关系类型数据")
except Exception as e:
logger.error(f"用户 {user_id} 初始化关系类型数据失败: {str(e)}", exc_info=True)
raise
async def init_db(user_id: str):
"""初始化指定用户的数据库,创建所有表并插入预置数据
Args:
user_id: 用户ID
"""
try:
logger.info(f"开始初始化用户 {user_id} 的数据库...")
engine = await get_engine(user_id)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
await _init_relationship_types(user_id)
logger.info(f"用户 {user_id} 的数据库初始化成功")
except Exception as e:
logger.error(f"用户 {user_id} 的数据库初始化失败: {str(e)}", exc_info=True)
raise
async def close_db():
"""关闭所有数据库连接"""
try:
logger.info("正在关闭所有数据库连接...")
for user_id, engine in _engine_cache.items():
await engine.dispose()
logger.info(f"用户 {user_id} 的数据库连接已关闭")
_engine_cache.clear()
logger.info("所有数据库连接已关闭")
except Exception as e:
logger.error(f"关闭数据库连接失败: {str(e)}", exc_info=True)
raise
+73
View File
@@ -0,0 +1,73 @@
"""初始化关系类型数据"""
import asyncio
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.database import AsyncSessionLocal
from app.models.relationship import RelationshipType
from app.logger import get_logger
logger = get_logger(__name__)
async def init_relationship_types():
"""初始化预置的关系类型数据"""
# 预置关系类型数据
relationship_types = [
# 家族关系
{"name": "父亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👨"},
{"name": "母亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👩"},
{"name": "兄弟", "category": "family", "reverse_name": "兄弟", "intimacy_range": "high", "icon": "👬"},
{"name": "姐妹", "category": "family", "reverse_name": "姐妹", "intimacy_range": "high", "icon": "👭"},
{"name": "子女", "category": "family", "reverse_name": "父母", "intimacy_range": "high", "icon": "👶"},
{"name": "配偶", "category": "family", "reverse_name": "配偶", "intimacy_range": "high", "icon": "💑"},
{"name": "恋人", "category": "family", "reverse_name": "恋人", "intimacy_range": "high", "icon": "💕"},
# 社交关系
{"name": "师父", "category": "social", "reverse_name": "徒弟", "intimacy_range": "high", "icon": "🎓"},
{"name": "徒弟", "category": "social", "reverse_name": "师父", "intimacy_range": "high", "icon": "📚"},
{"name": "朋友", "category": "social", "reverse_name": "朋友", "intimacy_range": "medium", "icon": "🤝"},
{"name": "同学", "category": "social", "reverse_name": "同学", "intimacy_range": "medium", "icon": "🎒"},
{"name": "邻居", "category": "social", "reverse_name": "邻居", "intimacy_range": "low", "icon": "🏘️"},
{"name": "知己", "category": "social", "reverse_name": "知己", "intimacy_range": "high", "icon": "💙"},
# 职业关系
{"name": "上司", "category": "professional", "reverse_name": "下属", "intimacy_range": "low", "icon": "👔"},
{"name": "下属", "category": "professional", "reverse_name": "上司", "intimacy_range": "low", "icon": "💼"},
{"name": "同事", "category": "professional", "reverse_name": "同事", "intimacy_range": "medium", "icon": "🤵"},
{"name": "合作伙伴", "category": "professional", "reverse_name": "合作伙伴", "intimacy_range": "medium", "icon": "🤜🤛"},
# 敌对关系
{"name": "敌人", "category": "hostile", "reverse_name": "敌人", "intimacy_range": "low", "icon": "⚔️"},
{"name": "仇人", "category": "hostile", "reverse_name": "仇人", "intimacy_range": "low", "icon": "💢"},
{"name": "竞争对手", "category": "hostile", "reverse_name": "竞争对手", "intimacy_range": "low", "icon": "🎯"},
{"name": "宿敌", "category": "hostile", "reverse_name": "宿敌", "intimacy_range": "low", "icon": ""},
]
async with AsyncSessionLocal() as session:
try:
# 检查是否已经有数据
result = await session.execute(select(RelationshipType))
existing = result.scalars().first()
if existing:
logger.info("关系类型数据已存在,跳过初始化")
return
# 插入预置数据
logger.info("开始插入关系类型数据...")
for rt_data in relationship_types:
relationship_type = RelationshipType(**rt_data)
session.add(relationship_type)
await session.commit()
logger.info(f"成功插入 {len(relationship_types)} 条关系类型数据")
except Exception as e:
logger.error(f"初始化关系类型数据失败: {str(e)}", exc_info=True)
await session.rollback()
raise
if __name__ == "__main__":
asyncio.run(init_relationship_types())
+158
View File
@@ -0,0 +1,158 @@
"""统一日志配置模块 - Uvicorn风格"""
import logging
import sys
from pathlib import Path
from logging.handlers import RotatingFileHandler
from typing import Optional
class UvicornFormatter(logging.Formatter):
"""Uvicorn风格的日志格式化器"""
# 日志级别颜色(ANSI转义码)
COLORS = {
'DEBUG': '\033[36m', # 青色
'INFO': '\033[32m', # 绿色
'WARNING': '\033[33m', # 黄色
'ERROR': '\033[31m', # 红色
'CRITICAL': '\033[35m', # 紫色
}
RESET = '\033[0m'
def __init__(self, use_colors: bool = True):
"""
初始化格式化器
Args:
use_colors: 是否使用颜色(控制台输出使用,文件输出不使用)
"""
super().__init__()
self.use_colors = use_colors
def format(self, record):
"""格式化日志记录为 Uvicorn 风格"""
# 获取日志级别名称
levelname = record.levelname
# 添加颜色(如果启用且终端支持)
if self.use_colors and sys.stderr.isatty():
colored_level = f"{self.COLORS.get(levelname, '')}{levelname}{self.RESET}"
else:
colored_level = levelname
# 添加请求追踪ID(如果存在)
request_id = getattr(record, 'request_id', None)
request_id_str = f" [{request_id}]" if request_id else ""
# Uvicorn风格格式: INFO: module_name - message [request_id]
# 注意:INFO后面有5个空格,保持对齐
return f"{colored_level}: {record.name}{request_id_str} - {record.getMessage()}"
# 全局标志,防止重复初始化
_logging_configured = False
def setup_logging(
level: str = "INFO",
log_to_file: bool = False,
log_file_path: Optional[str] = None,
max_bytes: int = 10 * 1024 * 1024,
backup_count: int = 30
):
"""
配置统一的 Uvicorn 风格日志系统
Args:
level: 日志级别 (DEBUG, INFO, WARNING, ERROR, CRITICAL)
log_to_file: 是否输出到文件
log_file_path: 日志文件路径
max_bytes: 单个日志文件最大字节数(默认10MB)
backup_count: 保留的备份文件数量(默认30个)
"""
global _logging_configured
# 如果已经配置过,直接返回
if _logging_configured:
return logging.getLogger()
# 获取根日志器
root_logger = logging.getLogger()
root_logger.setLevel(getattr(logging, level.upper()))
# 清除已有的处理器,避免重复
root_logger.handlers.clear()
# 1. 创建控制台处理器(带颜色)
console_handler = logging.StreamHandler(sys.stderr)
console_handler.setLevel(getattr(logging, level.upper()))
console_formatter = UvicornFormatter(use_colors=True)
console_handler.setFormatter(console_formatter)
root_logger.addHandler(console_handler)
# 2. 创建文件处理器(如果启用)
if log_to_file and log_file_path:
# 确保日志目录存在
log_file = Path(log_file_path)
log_file.parent.mkdir(parents=True, exist_ok=True)
# 使用RotatingFileHandler实现日志轮转
file_handler = RotatingFileHandler(
filename=log_file_path,
maxBytes=max_bytes,
backupCount=backup_count,
encoding='utf-8'
)
file_handler.setLevel(getattr(logging, level.upper()))
# 文件日志不使用颜色
file_formatter = UvicornFormatter(use_colors=False)
file_handler.setFormatter(file_formatter)
root_logger.addHandler(file_handler)
# 记录日志配置信息
root_logger.info(f"日志文件输出已启用: {log_file_path}")
root_logger.info(f"日志轮转配置: 单文件最大{max_bytes / 1024 / 1024:.1f}MB, 保留{backup_count}个备份")
# 配置第三方库的日志级别
_configure_third_party_loggers()
# 标记为已配置
_logging_configured = True
return root_logger
def _configure_third_party_loggers():
"""配置第三方库的日志级别"""
# SQLAlchemy - 禁用SQL日志
logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING)
logging.getLogger('sqlalchemy.pool').setLevel(logging.WARNING)
logging.getLogger('sqlalchemy.dialects').setLevel(logging.WARNING)
logging.getLogger('sqlalchemy.orm').setLevel(logging.WARNING)
# Watchfiles - 开发时的文件监控,降低级别
logging.getLogger('watchfiles').setLevel(logging.WARNING)
# httpx - HTTP客户端
logging.getLogger('httpx').setLevel(logging.WARNING)
# openai/anthropic - AI客户端库
logging.getLogger('openai').setLevel(logging.WARNING)
logging.getLogger('anthropic').setLevel(logging.WARNING)
# 应用模块 - 可根据需要调整
logging.getLogger('app.services.ai_service').setLevel(logging.WARNING)
logging.getLogger('app.api.wizard').setLevel(logging.WARNING)
def get_logger(name: str) -> logging.Logger:
"""
获取指定名称的日志器
Args:
name: 日志器名称,通常使用 __name__
Returns:
配置好的日志器实例
"""
return logging.getLogger(name)
+176
View File
@@ -0,0 +1,176 @@
"""FastAPI应用主入口"""
from fastapi import FastAPI, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse, FileResponse
from fastapi.exceptions import RequestValidationError
from contextlib import asynccontextmanager
from pathlib import Path
from app.config import settings
from app.database import close_db, _session_stats
from app.logger import setup_logging, get_logger
from app.middleware import RequestIDMiddleware
from app.middleware.auth_middleware import AuthMiddleware
setup_logging(
level=settings.log_level,
log_to_file=settings.log_to_file,
log_file_path=settings.log_file_path,
max_bytes=settings.log_max_bytes,
backup_count=settings.log_backup_count
)
logger = get_logger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期管理"""
logger.info("应用启动,等待用户登录...")
yield
await close_db()
logger.info("应用已关闭")
app = FastAPI(
title=settings.app_name,
version=settings.app_version,
description="AI写小说工具 - 智能小说创作助手",
lifespan=lifespan
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
"""处理请求验证错误"""
logger.error(f"请求验证失败: {exc.errors()}")
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content={
"detail": "请求参数验证失败",
"errors": exc.errors()
}
)
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
"""处理所有未捕获的异常"""
logger.error(f"未处理的异常: {type(exc).__name__}: {str(exc)}", exc_info=True)
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={
"detail": "服务器内部错误",
"message": str(exc) if settings.debug else "请稍后重试"
}
)
app.add_middleware(RequestIDMiddleware)
app.add_middleware(AuthMiddleware)
if settings.debug:
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
else:
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/health")
async def health_check():
"""健康检查"""
return {"status": "ok"}
@app.get("/health/db-sessions")
async def db_session_stats():
"""
数据库会话统计(监控连接泄漏)
返回:
- created: 总创建会话数
- closed: 总关闭会话数
- active: 当前活跃会话数(应该接近0)
- errors: 错误次数
- generator_exits: SSE断开次数
- last_check: 最后检查时间
"""
return {
"status": "ok",
"session_stats": _session_stats,
"warning": "活跃会话数过多" if _session_stats["active"] > 10 else None
}
from app.api import (
projects, outlines, characters, chapters,
wizard_stream, relationships, organizations,
auth, users
)
app.include_router(auth.router, prefix="/api")
app.include_router(users.router, prefix="/api")
app.include_router(projects.router, prefix="/api")
app.include_router(wizard_stream.router, prefix="/api")
app.include_router(outlines.router, prefix="/api")
app.include_router(characters.router, prefix="/api")
app.include_router(chapters.router, prefix="/api")
app.include_router(relationships.router, prefix="/api")
app.include_router(organizations.router, prefix="/api")
static_dir = Path(__file__).parent.parent / "static"
if static_dir.exists():
app.mount("/assets", StaticFiles(directory=str(static_dir / "assets")), name="assets")
@app.get("/{full_path:path}")
async def serve_spa(full_path: str):
"""服务单页应用,所有非API路径返回index.html"""
if full_path.startswith("api/"):
return JSONResponse(
status_code=404,
content={"detail": "API路径不存在"}
)
file_path = static_dir / full_path
if file_path.is_file():
return FileResponse(file_path)
index_file = static_dir / "index.html"
if index_file.exists():
return FileResponse(index_file)
return JSONResponse(
status_code=404,
content={"detail": "页面不存在"}
)
else:
logger.warning("静态文件目录不存在,请先构建前端: cd frontend && npm run build")
@app.get("/")
async def root():
return {
"message": "欢迎使用AI Story Creator",
"version": settings.app_version,
"docs": "/docs",
"notice": "请先构建前端: cd frontend && npm run build"
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"app.main:app",
host=settings.app_host,
port=settings.app_port,
reload=settings.debug
)
+4
View File
@@ -0,0 +1,4 @@
"""中间件模块"""
from .request_id import RequestIDMiddleware
__all__ = ['RequestIDMiddleware']
+39
View File
@@ -0,0 +1,39 @@
"""
认证中间件 - 从 Cookie 中提取用户信息并注入到 request.state
"""
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from app.user_manager import user_manager
class AuthMiddleware(BaseHTTPMiddleware):
"""认证中间件"""
async def dispatch(self, request: Request, call_next):
"""
处理请求,从 Cookie 中提取用户 ID 并注入到 request.state
"""
# 从 Cookie 中获取用户 ID
user_id = request.cookies.get("user_id")
# 注入到 request.state
if user_id:
user = await user_manager.get_user(user_id)
if user:
request.state.user_id = user_id
request.state.user = user
request.state.is_admin = user.is_admin
else:
# 用户不存在,清除状态
request.state.user_id = None
request.state.user = None
request.state.is_admin = False
else:
# 未登录
request.state.user_id = None
request.state.user = None
request.state.is_admin = False
# 继续处理请求
response = await call_next(request)
return response
+78
View File
@@ -0,0 +1,78 @@
"""请求追踪ID中间件"""
import uuid
import logging
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from typing import Callable
class RequestIDMiddleware(BaseHTTPMiddleware):
"""
请求追踪ID中间件
为每个请求生成唯一ID,并添加到日志上下文中
"""
async def dispatch(self, request: Request, call_next: Callable) -> Response:
"""
处理请求,添加追踪ID
Args:
request: 请求对象
call_next: 下一个处理器
Returns:
响应对象
"""
# 从请求头获取追踪ID,或生成新的
request_id = request.headers.get('X-Request-ID') or str(uuid.uuid4())
# 将请求ID存储到request.state中,方便后续访问
request.state.request_id = request_id
# 创建日志过滤器,自动添加request_id到日志记录
log_filter = RequestIDFilter(request_id)
# 获取根日志器并添加过滤器
root_logger = logging.getLogger()
root_logger.addFilter(log_filter)
try:
# 处理请求
response = await call_next(request)
# 将请求ID添加到响应头
response.headers['X-Request-ID'] = request_id
return response
finally:
# 移除过滤器,避免影响其他请求
root_logger.removeFilter(log_filter)
class RequestIDFilter(logging.Filter):
"""日志过滤器,为日志记录添加request_id属性"""
def __init__(self, request_id: str):
"""
初始化过滤器
Args:
request_id: 请求追踪ID
"""
super().__init__()
self.request_id = request_id
def filter(self, record: logging.LogRecord) -> bool:
"""
为日志记录添加request_id属性
Args:
record: 日志记录
Returns:
True(不过滤任何日志)
"""
record.request_id = self.request_id
return True
+26
View File
@@ -0,0 +1,26 @@
"""数据库模型"""
from app.models.project import Project
from app.models.outline import Outline
from app.models.character import Character
from app.models.chapter import Chapter
from app.models.generation_history import GenerationHistory
from app.models.settings import Settings
from app.models.relationship import (
RelationshipType,
CharacterRelationship,
Organization,
OrganizationMember
)
__all__ = [
"Project",
"Outline",
"Character",
"Chapter",
"GenerationHistory",
"Settings",
"RelationshipType",
"CharacterRelationship",
"Organization",
"OrganizationMember",
]
+24
View File
@@ -0,0 +1,24 @@
"""章节数据模型"""
from sqlalchemy import Column, String, Text, Integer, DateTime, ForeignKey
from sqlalchemy.sql import func
from app.database import Base
import uuid
class Chapter(Base):
"""章节表"""
__tablename__ = "chapters"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)
chapter_number = Column(Integer, nullable=False, comment="章节序号")
title = Column(String(200), nullable=False, comment="章节标题")
content = Column(Text, comment="章节内容")
summary = Column(Text, comment="章节摘要")
word_count = Column(Integer, default=0, comment="字数统计")
status = Column(String(20), default="draft", comment="章节状态")
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
def __repr__(self):
return f"<Chapter(id={self.id}, chapter_number={self.chapter_number}, title={self.title})>"
+44
View File
@@ -0,0 +1,44 @@
"""角色数据模型"""
from sqlalchemy import Column, String, Text, DateTime, ForeignKey, Boolean
from sqlalchemy.sql import func
from app.database import Base
import uuid
class Character(Base):
"""角色表(包括角色和组织)"""
__tablename__ = "characters"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)
# 基本信息
name = Column(String(100), nullable=False, comment="角色/组织名称")
age = Column(String(20), comment="年龄")
gender = Column(String(20), comment="性别")
is_organization = Column(Boolean, default=False, comment="是否为组织")
# 角色类型:protagonist(主角)/supporting(配角)/antagonist(反派)
role_type = Column(String(50), comment="角色类型")
# 角色详细信息
personality = Column(Text, comment="性格特点/组织特性")
background = Column(Text, comment="背景故事")
appearance = Column(Text, comment="外貌描述")
relationships = Column(Text, comment="人物关系(JSON)")
# 组织特有字段
organization_type = Column(String(100), comment="组织类型")
organization_purpose = Column(String(500), comment="组织目的")
organization_members = Column(Text, comment="组织成员(JSON)")
# 其他
avatar_url = Column(String(500), comment="头像URL")
traits = Column(Text, comment="特征标签(JSON)")
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
def __repr__(self):
entity_type = "组织" if self.is_organization else "角色"
return f"<Character(id={self.id}, name={self.name}, type={entity_type})>"
+23
View File
@@ -0,0 +1,23 @@
"""生成历史数据模型"""
from sqlalchemy import Column, String, Text, Integer, Float, DateTime, ForeignKey
from sqlalchemy.sql import func
from app.database import Base
import uuid
class GenerationHistory(Base):
"""生成历史表"""
__tablename__ = "generation_history"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)
chapter_id = Column(String(36), ForeignKey("chapters.id", ondelete="SET NULL"), nullable=True)
prompt = Column(Text, comment="使用的提示词")
generated_content = Column(Text, comment="生成的内容")
model = Column(String(50), comment="使用的模型")
tokens_used = Column(Integer, comment="消耗的token数")
generation_time = Column(Float, comment="生成耗时(秒)")
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
def __repr__(self):
return f"<GenerationHistory(id={self.id}, model={self.model})>"
+22
View File
@@ -0,0 +1,22 @@
"""大纲数据模型"""
from sqlalchemy import Column, String, Text, Integer, DateTime, ForeignKey
from sqlalchemy.sql import func
from app.database import Base
import uuid
class Outline(Base):
"""大纲表"""
__tablename__ = "outlines"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)
title = Column(String(200), nullable=False, comment="大纲标题")
content = Column(Text, comment="大纲内容")
structure = Column(Text, comment="结构化大纲数据(JSON)")
order_index = Column(Integer, comment="排序序号")
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
def __repr__(self):
return f"<Outline(id={self.id}, title={self.title})>"
+38
View File
@@ -0,0 +1,38 @@
"""项目数据模型"""
from sqlalchemy import Column, String, Text, DateTime, Integer
from sqlalchemy.sql import func
from app.database import Base
import uuid
class Project(Base):
"""项目表"""
__tablename__ = "projects"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
title = Column(String(200), nullable=False, comment="项目标题")
description = Column(Text, comment="项目简介")
theme = Column(Text, comment="主题")
genre = Column(String(50), comment="小说类型")
target_words = Column(Integer, default=0, comment="目标字数")
current_words = Column(Integer, default=0, comment="当前字数")
status = Column(String(20), default="planning", comment="创作状态")
wizard_status = Column(String(20), default="incomplete", comment="向导完成状态: incomplete/completed")
wizard_step = Column(Integer, default=0, comment="向导当前步骤: 0-4")
# 世界构建字段
world_time_period = Column(Text, comment="时间背景")
world_location = Column(Text, comment="地理位置")
world_atmosphere = Column(Text, comment="氛围基调")
world_rules = Column(Text, comment="世界规则")
# 项目配置
chapter_count = Column(Integer, comment="章节数量")
narrative_perspective = Column(String(50), comment="叙事视角:first_person/third_person/omniscient")
character_count = Column(Integer, default=5, comment="角色数量")
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
def __repr__(self):
return f"<Project(id={self.id}, title={self.title})>"
+116
View File
@@ -0,0 +1,116 @@
"""角色关系和组织管理数据模型"""
from sqlalchemy import Column, String, Integer, Text, DateTime, ForeignKey, Boolean
from sqlalchemy.sql import func
from app.database import Base
import uuid
class RelationshipType(Base):
"""关系类型定义表"""
__tablename__ = "relationship_types"
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
name = Column(String(50), nullable=False, comment="关系名称")
category = Column(String(20), nullable=False, comment="分类:family/social/hostile/professional")
reverse_name = Column(String(50), comment="反向关系名称")
intimacy_range = Column(String(20), comment="亲密度范围:high/medium/low")
icon = Column(String(50), comment="图标标识")
description = Column(Text, comment="关系描述")
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
def __repr__(self):
return f"<RelationshipType(id={self.id}, name={self.name}, category={self.category})>"
class CharacterRelationship(Base):
"""角色关系表"""
__tablename__ = "character_relationships"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="关系ID")
project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False, index=True, comment="项目ID")
# 关系双方
character_from_id = Column(String(36), ForeignKey("characters.id", ondelete="CASCADE"), nullable=False, index=True, comment="角色A的ID")
character_to_id = Column(String(36), ForeignKey("characters.id", ondelete="CASCADE"), nullable=False, index=True, comment="角色B的ID")
# 关系类型
relationship_type_id = Column(Integer, ForeignKey("relationship_types.id"), index=True, comment="关系类型ID")
relationship_name = Column(String(100), comment="自定义关系名称")
# 关系属性
intimacy_level = Column(Integer, default=50, comment="亲密度:0-100")
status = Column(String(20), default="active", comment="状态:active/broken/past/complicated")
description = Column(Text, comment="关系详细描述")
# 故事时间线
started_at = Column(String(100), comment="关系开始时间(故事时间)")
ended_at = Column(String(100), comment="关系结束时间(故事时间)")
# 来源标识
source = Column(String(20), default="ai", comment="来源:ai/manual/imported")
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
def __repr__(self):
return f"<CharacterRelationship(id={self.id}, from={self.character_from_id}, to={self.character_to_id})>"
class Organization(Base):
"""组织详情表"""
__tablename__ = "organizations"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="组织ID")
character_id = Column(String(36), ForeignKey("characters.id", ondelete="CASCADE"), nullable=False, unique=True, comment="关联的角色ID")
project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False, index=True, comment="项目ID")
# 组织层级
parent_org_id = Column(String(36), ForeignKey("organizations.id", ondelete="SET NULL"), comment="父组织ID")
level = Column(Integer, default=0, comment="组织层级")
# 组织属性
power_level = Column(Integer, default=50, comment="势力等级:0-100")
member_count = Column(Integer, default=0, comment="成员数量")
location = Column(Text, comment="所在地")
# 组织特色
motto = Column(String(200), comment="宗旨/口号")
color = Column(String(20), comment="代表颜色")
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
def __repr__(self):
return f"<Organization(id={self.id}, character_id={self.character_id})>"
class OrganizationMember(Base):
"""组织成员关系表"""
__tablename__ = "organization_members"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="成员关系ID")
organization_id = Column(String(36), ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False, index=True, comment="组织ID")
character_id = Column(String(36), ForeignKey("characters.id", ondelete="CASCADE"), nullable=False, index=True, comment="角色ID")
# 职位信息
position = Column(String(100), nullable=False, comment="职位名称")
rank = Column(Integer, default=0, comment="职位等级")
# 成员状态
status = Column(String(20), default="active", comment="状态:active/retired/expelled/deceased")
joined_at = Column(String(100), comment="加入时间(故事时间)")
left_at = Column(String(100), comment="离开时间(故事时间)")
# 成员属性
loyalty = Column(Integer, default=50, comment="忠诚度:0-100")
contribution = Column(Integer, default=0, comment="贡献度:0-100")
# 来源标识
source = Column(String(20), default="ai", comment="来源:ai/manual")
notes = Column(Text, comment="备注")
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
def __repr__(self):
return f"<OrganizationMember(id={self.id}, org={self.organization_id}, char={self.character_id})>"
+24
View File
@@ -0,0 +1,24 @@
"""设置数据模型"""
from sqlalchemy import Column, String, Text, Float, Integer, DateTime
from sqlalchemy.sql import func
from app.database import Base
import uuid
class Settings(Base):
"""设置表"""
__tablename__ = "settings"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
api_provider = Column(String(50), default="openai", comment="API提供商")
api_key = Column(String(500), comment="API密钥")
api_base_url = Column(String(500), comment="自定义API地址")
model_name = Column(String(100), default="gpt-4", comment="模型名称")
temperature = Column(Float, default=0.7, comment="温度参数")
max_tokens = Column(Integer, default=2000, comment="最大token数")
preferences = Column(Text, comment="其他偏好设置(JSON)")
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
def __repr__(self):
return f"<Settings(id={self.id}, api_provider={self.api_provider})>"
+1
View File
@@ -0,0 +1 @@
"""Pydantic数据模型"""
+57
View File
@@ -0,0 +1,57 @@
"""章节相关的Pydantic模型"""
from pydantic import BaseModel, Field
from typing import Optional
from datetime import datetime
class ChapterBase(BaseModel):
"""章节基础模型"""
title: str = Field(..., description="章节标题")
chapter_number: int = Field(..., description="章节序号")
content: Optional[str] = Field(None, description="章节内容")
summary: Optional[str] = Field(None, description="章节摘要")
word_count: Optional[int] = Field(0, description="字数")
status: Optional[str] = Field("draft", description="章节状态")
class ChapterCreate(BaseModel):
"""创建章节的请求模型"""
project_id: str = Field(..., description="所属项目ID")
title: str = Field(..., description="章节标题")
chapter_number: int = Field(..., description="章节序号")
content: Optional[str] = Field(None, description="章节内容")
summary: Optional[str] = Field(None, description="章节摘要")
status: Optional[str] = Field("draft", description="章节状态")
class ChapterUpdate(BaseModel):
"""更新章节的请求模型"""
title: Optional[str] = None
content: Optional[str] = None
# chapter_number 不允许修改,只能通过大纲的重排序来调整
summary: Optional[str] = None
# word_count 自动计算,不允许手动修改
status: Optional[str] = None
class ChapterResponse(BaseModel):
"""章节响应模型"""
id: str
project_id: str
title: str
chapter_number: int
content: Optional[str] = None
summary: Optional[str] = None
word_count: int = 0
status: str
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class ChapterListResponse(BaseModel):
"""章节列表响应模型"""
total: int
items: list[ChapterResponse]
+67
View File
@@ -0,0 +1,67 @@
"""角色相关的Pydantic模型"""
from pydantic import BaseModel, Field
from typing import Optional, List, Dict, Any
from datetime import datetime
class CharacterBase(BaseModel):
"""角色基础模型"""
name: str = Field(..., description="角色/组织姓名")
age: Optional[str] = Field(None, description="年龄")
gender: Optional[str] = Field(None, description="性别")
is_organization: bool = Field(False, description="是否为组织")
role_type: Optional[str] = Field(None, description="角色类型:protagonist/supporting/antagonist")
personality: Optional[str] = Field(None, description="性格特点/组织特性")
background: Optional[str] = Field(None, description="背景故事")
appearance: Optional[str] = Field(None, description="外貌特征")
relationships: Optional[str] = Field(None, description="人际关系(JSON)")
organization_type: Optional[str] = Field(None, description="组织类型")
organization_purpose: Optional[str] = Field(None, description="组织目的")
organization_members: Optional[str] = Field(None, description="组织成员(JSON)")
traits: Optional[str] = Field(None, description="特征标签(JSON)")
class CharacterUpdate(BaseModel):
"""更新角色的请求模型"""
name: Optional[str] = None
age: Optional[str] = None
gender: Optional[str] = None
is_organization: Optional[bool] = None
role_type: Optional[str] = None
personality: Optional[str] = None
background: Optional[str] = None
appearance: Optional[str] = None
relationships: Optional[str] = None
organization_type: Optional[str] = None
organization_purpose: Optional[str] = None
organization_members: Optional[str] = None
traits: Optional[str] = None
class CharacterResponse(CharacterBase):
"""角色响应模型"""
id: str
project_id: str
avatar_url: Optional[str] = None
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class CharacterGenerateRequest(BaseModel):
"""AI生成角色的请求模型"""
project_id: str = Field(..., description="项目ID")
name: Optional[str] = Field(None, description="角色名称")
role_type: Optional[str] = Field(None, description="角色类型")
background: Optional[str] = Field(None, description="角色背景")
requirements: Optional[str] = Field(None, description="特殊要求")
provider: Optional[str] = Field(None, description="AI提供商")
model: Optional[str] = Field(None, description="AI模型")
class CharacterListResponse(BaseModel):
"""角色列表响应模型"""
total: int
items: List[CharacterResponse]
+88
View File
@@ -0,0 +1,88 @@
"""大纲相关的Pydantic模型"""
from pydantic import BaseModel, Field
from typing import Optional
from datetime import datetime
class OutlineBase(BaseModel):
"""大纲基础模型"""
title: str = Field(..., description="章节标题")
content: str = Field(..., description="章节内容概要")
class OutlineCreate(BaseModel):
"""创建大纲的请求模型"""
project_id: str = Field(..., description="所属项目ID")
title: str = Field(..., description="章节标题")
content: str = Field(..., description="章节内容概要")
order_index: int = Field(..., description="章节序号", ge=1)
structure: Optional[str] = Field(None, description="结构化大纲数据(JSON)")
class OutlineUpdate(BaseModel):
"""更新大纲的请求模型"""
title: Optional[str] = None
content: Optional[str] = None
# order_index 不允许通过普通更新修改,只能通过 reorder_outlines 接口批量调整
# structure 暂不支持修改
class OutlineResponse(BaseModel):
"""大纲响应模型"""
id: str
project_id: str
title: str
content: str
structure: Optional[str] = None
order_index: int
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class OutlineGenerateRequest(BaseModel):
"""AI生成大纲的请求模型 - 支持全新生成和智能续写"""
project_id: str = Field(..., description="项目ID")
genre: Optional[str] = Field(None, description="小说类型,如:玄幻、都市、悬疑等")
theme: str = Field(..., description="小说主题")
chapter_count: int = Field(..., ge=1, description="章节数量")
narrative_perspective: str = Field(..., description="叙事视角")
world_context: Optional[dict] = Field(None, description="世界观背景")
characters_context: Optional[list] = Field(None, description="角色信息")
target_words: int = Field(100000, description="目标字数")
requirements: Optional[str] = Field(None, description="其他特殊要求")
provider: Optional[str] = Field(None, description="AI提供商")
model: Optional[str] = Field(None, description="AI模型")
# 续写相关参数
mode: str = Field("auto", description="生成模式: auto(自动判断), new(全新生成), continue(续写)")
story_direction: Optional[str] = Field(None, description="故事发展方向提示(续写时使用)")
plot_stage: str = Field("development", description="情节阶段: development(发展), climax(高潮), ending(结局)")
keep_existing: bool = Field(False, description="是否保留现有大纲(续写时)")
class ChapterOutlineGenerateRequest(BaseModel):
"""为单个章节生成大纲的请求模型"""
outline_id: str = Field(..., description="大纲ID")
context: Optional[str] = Field(None, description="额外上下文")
provider: Optional[str] = Field(None, description="AI提供商")
model: Optional[str] = Field(None, description="AI模型")
class OutlineListResponse(BaseModel):
"""大纲列表响应模型"""
total: int
items: list[OutlineResponse]
class OutlineReorderItem(BaseModel):
"""单个大纲重排序项"""
id: str = Field(..., description="大纲ID")
order_index: int = Field(..., description="新的序号", ge=1)
class OutlineReorderRequest(BaseModel):
"""大纲批量重排序请求"""
orders: list[OutlineReorderItem] = Field(..., description="排序列表")
+20
View File
@@ -0,0 +1,20 @@
"""AI去味相关的Pydantic模型"""
from pydantic import BaseModel, Field
from typing import Optional
class PolishRequest(BaseModel):
"""AI去味请求模型"""
original_text: str = Field(..., description="原始文本(AI生成的文本)")
project_id: Optional[int] = Field(None, description="项目ID(可选,用于记录历史)")
provider: Optional[str] = Field(None, description="AI提供商")
model: Optional[str] = Field(None, description="AI模型")
temperature: Optional[float] = Field(0.8, description="温度参数,建议0.7-0.9")
class PolishResponse(BaseModel):
"""AI去味响应模型"""
original_text: str = Field(..., description="原始文本")
polished_text: str = Field(..., description="去味后的文本")
word_count_before: int = Field(..., description="处理前字数")
word_count_after: int = Field(..., description="处理后字数")
+83
View File
@@ -0,0 +1,83 @@
"""项目相关的Pydantic模型"""
from pydantic import BaseModel, Field
from typing import Optional
from datetime import datetime
class ProjectBase(BaseModel):
"""项目基础模型"""
title: str = Field(..., description="项目标题")
description: Optional[str] = Field(None, description="项目描述")
theme: Optional[str] = Field(None, description="主题")
genre: Optional[str] = Field(None, description="小说类型")
target_words: Optional[int] = Field(None, description="目标字数")
class ProjectCreate(ProjectBase):
"""创建项目的请求模型"""
pass
class ProjectUpdate(BaseModel):
"""更新项目的请求模型"""
title: Optional[str] = None
description: Optional[str] = None
theme: Optional[str] = None
genre: Optional[str] = None
target_words: Optional[int] = None
status: Optional[str] = None
# wizard_status 和 wizard_step 只能通过向导API修改,普通更新不允许
world_time_period: Optional[str] = None
world_location: Optional[str] = None
world_atmosphere: Optional[str] = None
world_rules: Optional[str] = None
chapter_count: Optional[int] = None
narrative_perspective: Optional[str] = None
character_count: Optional[int] = None
# current_words 由章节内容自动计算,不允许手动修改
class ProjectResponse(ProjectBase):
"""项目响应模型"""
id: str # UUID字符串
status: str
current_words: int
wizard_status: Optional[str] = None
wizard_step: Optional[int] = None
world_time_period: Optional[str] = None
world_location: Optional[str] = None
world_atmosphere: Optional[str] = None
world_rules: Optional[str] = None
chapter_count: Optional[int] = None
narrative_perspective: Optional[str] = None
character_count: Optional[int] = None
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class ProjectListResponse(BaseModel):
"""项目列表响应模型"""
total: int
items: list[ProjectResponse]
class ProjectWizardRequest(BaseModel):
"""项目创建向导请求模型"""
title: str = Field(..., description="书名")
theme: str = Field(..., description="主题")
genre: Optional[str] = Field(None, description="类型")
chapter_count: int = Field(..., ge=1, description="章节数量")
narrative_perspective: str = Field(..., description="叙事视角")
character_count: int = Field(5, ge=5, description="角色数量(至少5个)")
target_words: Optional[int] = Field(None, description="目标字数")
class WorldBuildingResponse(BaseModel):
"""世界构建响应模型"""
time_period: str = Field(..., description="时间背景")
location: str = Field(..., description="地理位置")
atmosphere: str = Field(..., description="氛围基调")
rules: str = Field(..., description="世界规则")
+204
View File
@@ -0,0 +1,204 @@
"""关系管理相关的Pydantic模型"""
from pydantic import BaseModel, Field
from typing import Optional, List
from datetime import datetime
# ============ 关系类型相关 ============
class RelationshipTypeResponse(BaseModel):
"""关系类型响应模型"""
id: int
name: str
category: str
reverse_name: Optional[str] = None
intimacy_range: Optional[str] = None
icon: Optional[str] = None
description: Optional[str] = None
created_at: datetime
class Config:
from_attributes = True
# ============ 角色关系相关 ============
class CharacterRelationshipBase(BaseModel):
"""角色关系基础模型"""
relationship_type_id: Optional[int] = Field(None, description="关系类型ID")
relationship_name: Optional[str] = Field(None, description="自定义关系名称")
intimacy_level: int = Field(50, ge=0, le=100, description="亲密度:0-100")
status: str = Field("active", description="状态:active/broken/past/complicated")
description: Optional[str] = Field(None, description="关系描述")
started_at: Optional[str] = Field(None, description="关系开始时间(故事时间)")
ended_at: Optional[str] = Field(None, description="关系结束时间")
class CharacterRelationshipCreate(CharacterRelationshipBase):
"""创建角色关系的请求模型"""
project_id: str = Field(..., description="项目ID")
character_from_id: str = Field(..., description="角色A的ID")
character_to_id: str = Field(..., description="角色B的ID")
class CharacterRelationshipUpdate(BaseModel):
"""更新角色关系的请求模型"""
relationship_type_id: Optional[int] = None
relationship_name: Optional[str] = None
intimacy_level: Optional[int] = Field(None, ge=0, le=100)
status: Optional[str] = None
description: Optional[str] = None
started_at: Optional[str] = None
ended_at: Optional[str] = None
class CharacterRelationshipResponse(CharacterRelationshipBase):
"""角色关系响应模型"""
id: str
project_id: str
character_from_id: str
character_to_id: str
source: str
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class RelationshipGraphNode(BaseModel):
"""关系图谱节点"""
id: str
name: str
type: str # character / organization
role_type: Optional[str] = None
avatar: Optional[str] = None
class RelationshipGraphLink(BaseModel):
"""关系图谱连线"""
source: str
target: str
relationship: str
intimacy: int
status: str
class RelationshipGraphData(BaseModel):
"""关系图谱数据"""
nodes: List[RelationshipGraphNode]
links: List[RelationshipGraphLink]
# ============ 组织相关 ============
class OrganizationBase(BaseModel):
"""组织基础模型"""
parent_org_id: Optional[str] = Field(None, description="父组织ID")
level: int = Field(0, description="组织层级")
power_level: int = Field(50, ge=0, le=100, description="势力等级")
location: Optional[str] = Field(None, description="所在地")
motto: Optional[str] = Field(None, description="组织宗旨")
color: Optional[str] = Field(None, description="代表颜色")
class OrganizationCreate(OrganizationBase):
"""创建组织的请求模型"""
character_id: str = Field(..., description="关联的角色ID(组织记录)")
project_id: str = Field(..., description="项目ID")
class OrganizationUpdate(BaseModel):
"""更新组织的请求模型"""
parent_org_id: Optional[str] = None
level: Optional[int] = None
power_level: Optional[int] = Field(None, ge=0, le=100)
location: Optional[str] = None
motto: Optional[str] = None
color: Optional[str] = None
class OrganizationResponse(OrganizationBase):
"""组织响应模型"""
id: str
character_id: str
project_id: str
member_count: int
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class OrganizationDetailResponse(BaseModel):
"""组织详情响应(包含基本信息)"""
id: str
character_id: str
name: str
type: Optional[str] = None
purpose: Optional[str] = None
member_count: int
power_level: int
location: Optional[str] = None
motto: Optional[str] = None
color: Optional[str] = None
# ============ 组织成员相关 ============
class OrganizationMemberBase(BaseModel):
"""组织成员基础模型"""
position: str = Field(..., description="职位名称")
rank: int = Field(0, description="职位等级")
status: str = Field("active", description="状态:active/retired/expelled/deceased")
joined_at: Optional[str] = Field(None, description="加入时间(故事时间)")
left_at: Optional[str] = Field(None, description="离开时间")
loyalty: int = Field(50, ge=0, le=100, description="忠诚度")
contribution: int = Field(0, ge=0, le=100, description="贡献度")
notes: Optional[str] = Field(None, description="备注")
class OrganizationMemberCreate(OrganizationMemberBase):
"""创建组织成员的请求模型"""
character_id: str = Field(..., description="角色ID")
class OrganizationMemberUpdate(BaseModel):
"""更新组织成员的请求模型"""
position: Optional[str] = None
rank: Optional[int] = None
status: Optional[str] = None
joined_at: Optional[str] = None
left_at: Optional[str] = None
loyalty: Optional[int] = Field(None, ge=0, le=100)
contribution: Optional[int] = Field(None, ge=0, le=100)
notes: Optional[str] = None
class OrganizationMemberResponse(OrganizationMemberBase):
"""组织成员响应模型"""
id: str
organization_id: str
character_id: str
source: str
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class OrganizationMemberDetailResponse(BaseModel):
"""组织成员详情响应(包含角色信息)"""
id: str
character_id: str
character_name: str
position: str
rank: int
loyalty: int
contribution: int
status: str
joined_at: Optional[str] = None
left_at: Optional[str] = None
notes: Optional[str] = None
+1
View File
@@ -0,0 +1 @@
"""服务层模块"""
+363
View File
@@ -0,0 +1,363 @@
"""AI服务封装 - 统一的OpenAI和Claude接口"""
from typing import Optional, AsyncGenerator, List, Dict, Any
from openai import AsyncOpenAI
from anthropic import AsyncAnthropic
from app.config import settings
from app.logger import get_logger
import httpx
logger = get_logger(__name__)
class AIService:
"""AI服务统一接口"""
def __init__(self):
"""初始化AI客户端(优化并发性能)"""
# 初始化OpenAI客户端
if settings.openai_api_key:
# 创建自定义的httpx客户端来避免proxies参数问题
try:
# 配置连接池限制,支持高并发
# max_keepalive_connections: 保持活跃的连接数(提高复用率)
# max_connections: 最大并发连接数(防止资源耗尽)
limits = httpx.Limits(
max_keepalive_connections=50, # 保持50个活跃连接
max_connections=100, # 最多100个并发连接
keepalive_expiry=30.0 # 30秒后过期未使用的连接
)
# 使用httpx.AsyncClient并设置超时和连接池
# connect: 连接超时10秒
# read: 读取超时180秒(3分钟,适合长文本生成)
# write: 写入超时10秒
# pool: 连接池超时10秒
http_client = httpx.AsyncClient(
timeout=httpx.Timeout(
connect=10.0,
read=180.0,
write=10.0,
pool=10.0
),
limits=limits
)
client_kwargs = {
"api_key": settings.openai_api_key,
"http_client": http_client
}
if settings.openai_base_url:
client_kwargs["base_url"] = settings.openai_base_url
self.openai_client = AsyncOpenAI(**client_kwargs)
logger.info("✅ OpenAI客户端初始化成功")
logger.info(" - 超时设置:连接10s,读取180s")
logger.info(" - 连接池:50个保活连接,最大100个并发")
except Exception as e:
logger.error(f"OpenAI客户端初始化失败: {e}")
self.openai_client = None
else:
self.openai_client = None
logger.warning("OpenAI API key未配置")
# 初始化Anthropic客户端
if settings.anthropic_api_key:
try:
# 为Anthropic设置相同的超时和连接池配置
limits = httpx.Limits(
max_keepalive_connections=50,
max_connections=100,
keepalive_expiry=30.0
)
http_client = httpx.AsyncClient(
timeout=httpx.Timeout(
connect=10.0,
read=180.0,
write=10.0,
pool=10.0
),
limits=limits
)
client_kwargs = {
"api_key": settings.anthropic_api_key,
"http_client": http_client
}
if settings.anthropic_base_url:
client_kwargs["base_url"] = settings.anthropic_base_url
self.anthropic_client = AsyncAnthropic(**client_kwargs)
logger.info("✅ Anthropic客户端初始化成功")
logger.info(" - 超时设置:连接10s,读取180s")
logger.info(" - 连接池:50个保活连接,最大100个并发")
except Exception as e:
logger.error(f"Anthropic客户端初始化失败: {e}")
self.anthropic_client = None
else:
self.anthropic_client = None
logger.warning("Anthropic API key未配置")
async def generate_text(
self,
prompt: str,
provider: Optional[str] = None,
model: Optional[str] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
system_prompt: Optional[str] = None
) -> str:
"""
生成文本
Args:
prompt: 用户提示词
provider: AI提供商 (openai/anthropic)
model: 模型名称
temperature: 温度参数
max_tokens: 最大token数
system_prompt: 系统提示词
Returns:
生成的文本
"""
provider = provider or settings.default_ai_provider
model = model or settings.default_model
temperature = temperature or settings.default_temperature
max_tokens = max_tokens or settings.default_max_tokens
if provider == "openai":
return await self._generate_openai(
prompt, model, temperature, max_tokens, system_prompt
)
elif provider == "anthropic":
return await self._generate_anthropic(
prompt, model, temperature, max_tokens, system_prompt
)
else:
raise ValueError(f"不支持的AI提供商: {provider}")
async def generate_text_stream(
self,
prompt: str,
provider: Optional[str] = None,
model: Optional[str] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
system_prompt: Optional[str] = None
) -> AsyncGenerator[str, None]:
"""
流式生成文本
Args:
prompt: 用户提示词
provider: AI提供商
model: 模型名称
temperature: 温度参数
max_tokens: 最大token数
system_prompt: 系统提示词
Yields:
生成的文本片段
"""
provider = provider or settings.default_ai_provider
model = model or settings.default_model
temperature = temperature or settings.default_temperature
max_tokens = max_tokens or settings.default_max_tokens
if provider == "openai":
async for chunk in self._generate_openai_stream(
prompt, model, temperature, max_tokens, system_prompt
):
yield chunk
elif provider == "anthropic":
async for chunk in self._generate_anthropic_stream(
prompt, model, temperature, max_tokens, system_prompt
):
yield chunk
else:
raise ValueError(f"不支持的AI提供商: {provider}")
async def _generate_openai(
self,
prompt: str,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str]
) -> str:
"""使用OpenAI生成文本"""
if not self.openai_client:
raise ValueError("OpenAI客户端未初始化,请检查API key配置")
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
try:
logger.info(f"🔵 开始调用OpenAI API")
logger.info(f" - 模型: {model}")
logger.info(f" - 温度: {temperature}")
logger.info(f" - 最大tokens: {max_tokens}")
logger.info(f" - Prompt长度: {len(prompt)} 字符")
logger.info(f" - 消息数量: {len(messages)}")
response = await self.openai_client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens
)
logger.info(f"✅ OpenAI API调用成功")
logger.info(f" - 响应ID: {response.id if hasattr(response, 'id') else 'N/A'}")
logger.info(f" - 选项数量: {len(response.choices)}")
if not response.choices:
logger.error("❌ OpenAI返回的choices为空")
return ""
content = response.choices[0].message.content
logger.info(f" - 返回内容长度: {len(content) if content else 0} 字符")
if content:
logger.info(f" - 返回内容预览(前200字符): {content[:200]}")
return content
else:
logger.error("❌ OpenAI返回了空内容")
logger.error(f" - 完整响应: {response}")
raise ValueError("AI返回了空内容,请检查API配置或稍后重试")
except Exception as e:
logger.error(f"❌ OpenAI API调用失败")
logger.error(f" - 错误类型: {type(e).__name__}")
logger.error(f" - 错误信息: {str(e)}")
logger.error(f" - 模型: {model}")
raise
async def _generate_openai_stream(
self,
prompt: str,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str]
) -> AsyncGenerator[str, None]:
"""使用OpenAI流式生成文本"""
if not self.openai_client:
raise ValueError("OpenAI客户端未初始化,请检查API key配置")
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
try:
logger.info(f"🔵 开始调用OpenAI流式API")
logger.info(f" - 模型: {model}")
logger.info(f" - Prompt长度: {len(prompt)} 字符")
logger.info(f" - 最大tokens: {max_tokens}")
stream = await self.openai_client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
stream=True
)
logger.info(f"✅ OpenAI流式API连接成功,开始接收数据...")
chunk_count = 0
async for chunk in stream:
if chunk.choices and len(chunk.choices) > 0:
if chunk.choices[0].delta.content:
chunk_count += 1
yield chunk.choices[0].delta.content
logger.info(f"✅ OpenAI流式生成完成,共接收 {chunk_count} 个chunk")
except httpx.TimeoutException as e:
logger.error(f"❌ OpenAI流式API超时")
logger.error(f" - 错误: {str(e)}")
logger.error(f" - 提示: 请检查网络连接或考虑缩短prompt长度")
raise TimeoutError(f"AI服务超时(180秒),请稍后重试或减少上下文长度") from e
except Exception as e:
logger.error(f"❌ OpenAI流式API调用失败: {str(e)}")
logger.error(f" - 错误类型: {type(e).__name__}")
raise
async def _generate_anthropic(
self,
prompt: str,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str]
) -> str:
"""使用Anthropic生成文本"""
if not self.anthropic_client:
raise ValueError("Anthropic客户端未初始化,请检查API key配置")
try:
response = await self.anthropic_client.messages.create(
model=model,
max_tokens=max_tokens,
temperature=temperature,
system=system_prompt or "",
messages=[{"role": "user", "content": prompt}]
)
return response.content[0].text
except Exception as e:
logger.error(f"Anthropic API调用失败: {str(e)}")
raise
async def _generate_anthropic_stream(
self,
prompt: str,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str]
) -> AsyncGenerator[str, None]:
"""使用Anthropic流式生成文本"""
if not self.anthropic_client:
raise ValueError("Anthropic客户端未初始化,请检查API key配置")
try:
logger.info(f"🔵 开始调用Anthropic流式API")
logger.info(f" - 模型: {model}")
logger.info(f" - Prompt长度: {len(prompt)} 字符")
logger.info(f" - 最大tokens: {max_tokens}")
async with self.anthropic_client.messages.stream(
model=model,
max_tokens=max_tokens,
temperature=temperature,
system=system_prompt or "",
messages=[{"role": "user", "content": prompt}]
) as stream:
logger.info(f"✅ Anthropic流式API连接成功,开始接收数据...")
chunk_count = 0
async for text in stream.text_stream:
chunk_count += 1
yield text
logger.info(f"✅ Anthropic流式生成完成,共接收 {chunk_count} 个chunk")
except httpx.TimeoutException as e:
logger.error(f"❌ Anthropic流式API超时")
logger.error(f" - 错误: {str(e)}")
raise TimeoutError(f"AI服务超时(180秒),请稍后重试或减少上下文长度") from e
except Exception as e:
logger.error(f"❌ Anthropic流式API调用失败: {str(e)}")
logger.error(f" - 错误类型: {type(e).__name__}")
raise
# 创建全局AI服务实例
ai_service = AIService()
+149
View File
@@ -0,0 +1,149 @@
"""
LinuxDO OAuth2 服务
"""
import httpx
import secrets
from typing import Optional, Dict, Any
from app.config import settings
class LinuxDOOAuthService:
"""LinuxDO OAuth2 服务类"""
# LinuxDO OAuth2 端点
AUTHORIZE_URL = "https://connect.linux.do/oauth2/authorize"
TOKEN_URL = "https://connect.linux.do/oauth2/token"
USERINFO_URL = "https://connect.linux.do/api/user" # 修复:使用正确的用户信息端点
def __init__(self):
self.client_id = settings.LINUXDO_CLIENT_ID
self.client_secret = settings.LINUXDO_CLIENT_SECRET
self.redirect_uri = settings.LINUXDO_REDIRECT_URI
# 验证redirect_uri配置
if not self.redirect_uri:
raise ValueError(
"LINUXDO_REDIRECT_URI 未配置!\n"
"请在 .env 文件中设置正确的回调地址:\n"
"本地开发: LINUXDO_REDIRECT_URI=http://localhost:8000/api/auth/callback\n"
"Docker部署: LINUXDO_REDIRECT_URI=https://your-domain.com/api/auth/callback"
)
# 警告:检查是否使用了localhost(在非开发环境)
if not settings.debug and "localhost" in self.redirect_uri.lower():
import logging
logger = logging.getLogger(__name__)
logger.warning(
f"⚠️ 生产环境检测到使用 localhost 作为回调地址: {self.redirect_uri}\n"
"这可能导致OAuth回调失败!请使用实际的域名或服务器IP。"
)
def generate_state(self) -> str:
"""生成随机 state 参数"""
return secrets.token_urlsafe(32)
def get_authorization_url(self, state: str) -> str:
"""
获取授权 URL
Args:
state: 随机 state 参数
Returns:
授权 URL
"""
params = {
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"response_type": "code",
"scope": "read",
"state": state
}
query_string = "&".join([f"{k}={v}" for k, v in params.items()])
return f"{self.AUTHORIZE_URL}?{query_string}"
async def get_access_token(self, code: str) -> Optional[Dict[str, Any]]:
"""
使用授权码获取访问令牌
Args:
code: 授权码
Returns:
包含 access_token 的字典,失败返回 None
"""
data = {
"client_id": self.client_id,
"client_secret": self.client_secret,
"code": code,
"grant_type": "authorization_code",
"redirect_uri": self.redirect_uri
}
try:
async with httpx.AsyncClient() as client:
response = await client.post(
self.TOKEN_URL,
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"}
)
if response.status_code == 200:
return response.json()
else:
print(f"获取访问令牌失败: {response.status_code} {response.text}")
return None
except Exception as e:
print(f"获取访问令牌异常: {e}")
return None
async def get_user_info(self, access_token: str) -> Optional[Dict[str, Any]]:
"""
使用访问令牌获取用户信息
Args:
access_token: 访问令牌
Returns:
用户信息字典,失败返回 None
"""
try:
# 添加真实浏览器请求头,避免被 Cloudflare 拦截
headers = {
"Authorization": f"Bearer {access_token}",
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
"Accept": "application/json",
"Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8",
}
# 不自动处理编码,让 httpx 自动解压
async with httpx.AsyncClient(follow_redirects=True, timeout=30.0) as client:
response = await client.get(
self.USERINFO_URL,
headers=headers
)
print(f"获取用户信息响应状态: {response.status_code}")
print(f"响应头: {response.headers}")
if response.status_code == 200:
try:
user_data = response.json()
print(f"用户信息: {user_data}")
return user_data
except Exception as json_error:
print(f"解析 JSON 失败: {json_error}")
print(f"响应内容前100字符: {response.text[:100]}")
return None
else:
print(f"获取用户信息失败: {response.status_code}")
print(f"响应内容: {response.text[:200]}")
return None
except Exception as e:
print(f"获取用户信息异常: {type(e).__name__}: {str(e)}")
import traceback
traceback.print_exc()
return None
+730
View File
@@ -0,0 +1,730 @@
"""提示词管理服务"""
from typing import Dict, Any
import json
class PromptService:
"""提示词模板管理"""
# 世界构建提示词
WORLD_BUILDING = """你是一位资深的世界观设计师。请根据以下信息构建一个完整的小说世界观:
书名:{title}
主题:{theme}
类型:{genre}
请生成包含以下内容的世界构建框架:
1. **时间背景**:具体的时代设定、时间流逝特点、重要历史事件
2. **地理位置**:主要地点描述、地理环境特征、空间布局
3. **氛围基调**:整体氛围感觉、情感色彩、视觉风格
4. **世界规则**:基本运行法则、特殊设定、社会规则和禁忌、权力结构
要求:
- 与主题高度契合
- 设定要合理自洽
- 为故事发展提供支撑
- 具有独特性和吸引力
**重要:你必须只返回纯JSON格式,不要包含任何markdown标记、代码块标记或其他说明文字。**
请严格按照以下JSON格式返回(每个字段为200-300字的文本描述):
{{
"time_period": "时间背景的详细描述,包括时代设定、时间特点、历史事件",
"location": "地理位置的详细描述,包括主要地点、环境特征、空间布局",
"atmosphere": "氛围基调的详细描述,包括整体氛围、情感色彩、视觉风格",
"rules": "世界规则的详细描述,包括运行法则、特殊设定、社会规则、权力结构"
}}
再次强调:只返回纯JSON对象,不要有```json```这样的标记,不要有任何额外的文字说明。"""
# 批量角色生成提示词
CHARACTERS_BATCH_GENERATION = """你是一位专业的角色设定师。请根据以下世界观和要求,生成{count}个立体丰满的角色和组织:
世界观信息:
- 时间背景:{time_period}
- 地理位置:{location}
- 氛围基调:{atmosphere}
- 世界规则:{rules}
主题:{theme}
类型:{genre}
特殊要求:{requirements}
【数量要求 - 必须严格遵守】
请精确生成{count}个实体,不多不少。数组中必须包含且仅包含{count}个对象。
实体类型分配:
- 至少1个主角(protagonist
- 多个配角(supporting
- 可以包含反派(antagonist
- 可以包含1-2个重要组织
要求:
- 角色要符合世界观设定
- 性格和背景要有深度
- 角色之间要有关系网络
- 组织要有存在的合理性
- 所有实体要为故事服务
**重要:你必须只返回纯JSON数组格式,不要包含任何markdown标记、代码块标记或其他说明文字。**
请严格按照以下JSON数组格式返回(每个角色为数组中的一个对象):
[
{{
"name": "角色姓名",
"age": 25,
"gender": "男/女/其他",
"is_organization": false,
"role_type": "protagonist/supporting/antagonist",
"personality": "性格特点的详细描述(100-200字),包括核心性格、优缺点、特殊习惯",
"background": "背景故事的详细描述(100-200字),包括家庭背景、成长经历、重要转折",
"appearance": "外貌描述(50-100字),包括身高、体型、面容、着装风格",
"traits": ["特长1", "特长2", "特长3"],
"relationships_array": [
{{
"target_character_name": "已生成的角色名称",
"relationship_type": "关系类型(师父/朋友/敌人/父亲/母亲等)",
"intimacy_level": 75,
"description": "关系描述"
}}
],
"organization_memberships": [
{{
"organization_name": "已生成的组织名称",
"position": "职位",
"rank": 5,
"loyalty": 80
}}
]
}},
{{
"name": "组织名称",
"is_organization": true,
"role_type": "supporting",
"personality": "组织特性描述(100-200字),包括运作方式、核心理念、行事风格",
"background": "组织背景(100-200字),包括建立历史、发展历程、重要事件",
"appearance": "组织外在表现(50-100字),如总部位置、标志性建筑等",
"organization_type": "组织类型",
"organization_purpose": "组织目的",
"organization_members": ["成员1", "成员2"],
"traits": []
}}
]
**关系类型参考(从中选择或自定义):**
- 家族:父亲、母亲、兄弟、姐妹、子女、配偶、恋人
- 社交:师父、徒弟、朋友、同学、同事、邻居、知己
- 职业:上司、下属、合作伙伴
- 敌对:敌人、仇人、竞争对手、宿敌
**重要说明:**
1. **数量控制**:数组中必须精确包含{count}个对象,不能多也不能少
2. **关系约束**relationships_array只能引用本批次中已经出现的角色名称
3. **组织约束**organization_memberships只能引用本批次中is_organization=true的实体名称
4. **禁止幻觉**:不要引用任何不存在的角色或组织,如果没有可引用的就留空数组[]
5. intimacy_level和loyalty都是0-100的整数
6. 角色之间要形成合理的关系网络
**示例说明**
- 如果生成了角色A、组织B、角色C,则角色A的organization_memberships只能是[组织B],不能是其他组织
- 如果角色A在数组第一位,它的relationships_array必须为空[],因为还没有其他角色
- 如果角色C在数组第三位,它的relationships_array可以引用角色A,但不能引用不存在的角色D
再次强调:
1. 只返回纯JSON数组,不要有```json```这样的标记
2. 数组中必须精确包含{count}个对象
3. 不要引用任何本批次中不存在的角色或组织名称"""
# 完整大纲生成提示词
COMPLETE_OUTLINE_GENERATION = """你是一位经验丰富的小说作家和编剧。请根据以下信息生成完整的{chapter_count}章小说大纲:
基本信息:
- 书名:{title}
- 主题:{theme}
- 类型:{genre}
- 章节数:{chapter_count}
- 叙事视角:{narrative_perspective}
- 目标字数:{target_words}
世界观:
- 时间背景:{time_period}
- 地理位置:{location}
- 氛围基调:{atmosphere}
- 世界规则:{rules}
角色信息:
{characters_info}
其他要求:{requirements}
整体要求:
- 结构完整:起承转合清晰
- 情节连贯:章节之间紧密衔接
- 冲突递进:矛盾逐步升级
- 人物成长:角色有明确的变化弧线
- 节奏把控:有张有弛
- 视角统一:采用{narrative_perspective}视角叙事
**重要:你必须只返回纯JSON数组格式,不要包含任何markdown标记、代码块标记或其他说明文字。**
请严格按照以下JSON数组格式返回(共{chapter_count}个章节对象):
[
{{
"chapter_number": 1,
"title": "第一章标题",
"summary": "章节概要的详细描述(100-200字),包含主要情节、冲突、转折等",
"scenes": ["场景1描述", "场景2描述", "场景3描述"],
"characters": ["角色1", "角色2"],
"key_points": ["情节要点1", "情节要点2"],
"emotion": "本章情感基调",
"goal": "本章叙事目标"
}},
{{
"chapter_number": 2,
"title": "第二章标题",
"summary": "章节概要...",
"scenes": ["场景1", "场景2"],
"characters": ["角色1", "角色2"],
"key_points": ["要点1", "要点2"],
"emotion": "情感基调",
"goal": "叙事目标"
}}
]
再次强调:只返回纯JSON数组,不要有```json```这样的标记,不要有任何额外的文字说明。数组中要包含{chapter_count}个章节对象。"""
# 大纲续写提示词
OUTLINE_CONTINUE_GENERATION = """你是一位经验丰富的小说作家和编剧。请基于以下信息续写小说大纲:
【项目信息】
- 书名:{title}
- 主题:{theme}
- 类型:{genre}
- 叙事视角:{narrative_perspective}
- 续写章节数:{chapter_count}
【世界观】
- 时间背景:{time_period}
- 地理位置:{location}
- 氛围基调:{atmosphere}
- 世界规则:{rules}
【角色信息】
{characters_info}
【已有章节概览】(共{current_chapter_count}章)
{all_chapters_brief}
【最近剧情】
{recent_plot}
【续写指导】
- 当前情节阶段:{plot_stage_instruction}
- 起始章节编号:第{start_chapter}
- 故事发展方向:{story_direction}
- 其他要求:{requirements}
请生成第{start_chapter}章到第{end_chapter}章的大纲。
要求:
- 与前文自然衔接,保持故事连贯性
- 遵循情节阶段的发展要求
- 保持与已有章节相同的风格和详细程度
- 推进角色成长和情节发展
**重要:你必须只返回纯JSON数组格式,不要包含任何markdown标记、代码块标记或其他说明文字。**
请严格按照以下JSON数组格式返回(共{chapter_count}个章节对象):
[
{{
"chapter_number": {start_chapter},
"title": "章节标题",
"summary": "章节概要的详细描述(100-200字),包含主要情节、角色互动、关键事件、冲突与转折",
"scenes": ["场景1描述", "场景2描述", "场景3描述"],
"characters": ["涉及角色1", "涉及角色2"],
"key_points": ["情节要点1", "情节要点2"],
"emotion": "本章情感基调",
"goal": "本章叙事目标"
}},
{{
"chapter_number": {start_chapter} + 1,
"title": "章节标题",
"summary": "章节概要...",
"scenes": ["场景1", "场景2"],
"characters": ["角色1", "角色2"],
"key_points": ["要点1", "要点2"],
"emotion": "情感基调",
"goal": "叙事目标"
}}
]
再次强调:
1. 只返回纯JSON数组,不要有```json```这样的标记
2. 数组中要包含{chapter_count}个章节对象
3. 每个summary必须是100-200字的详细描述
4. 确保字段结构与已有章节完全一致"""
# AI去味提示词(核心特色功能)
AI_DENOISING = """你是一位追求自然写作风格的编辑。你的任务是将AI生成的文本改写得更像人类作家的手笔。
原文:
{original_text}
修改要求:
1. 去除AI痕迹:
- 删除过于工整的排比句
- 减少重复的修辞手法
- 去掉刻意的对称结构
- 避免机械式的总结陈词
2. 增加人性化:
- 使用更口语化的表达
- 添加不完美的细节
- 保留适度的随意性
- 增加真实的情感波动
3. 优化叙事:
- 让节奏更自然不做作
- 用简单词汇替换华丽辞藻
- 保持叙述的松弛感
- 让对话更生活化
4. 保持原意:
- 不改变核心情节
- 保留关键信息点
- 维持角色性格
- 确保逻辑连贯
修改风格:
- 像是一个喜欢讲故事的普通人写的
- 有点粗糙但很真诚
- 自然流畅不刻意
- 让人读起来很舒服
请直接输出修改后的文本,无需解释。"""
# 章节完整创作提示词
CHAPTER_GENERATION = """你是一位专业的小说作家。请根据以下信息创作本章内容:
项目信息:
- 书名:{title}
- 主题:{theme}
- 类型:{genre}
- 叙事视角:{narrative_perspective}
世界观:
- 时间背景:{time_period}
- 地理位置:{location}
- 氛围基调:{atmosphere}
- 世界规则:{rules}
角色信息:
{characters_info}
全书大纲:
{outlines_context}
本章信息:
- 章节序号:第{chapter_number}
- 章节标题:{chapter_title}
- 章节大纲:{chapter_outline}
创作要求:
1. 严格按照大纲内容展开情节
2. 保持与前后章节的连贯性
3. 符合角色性格设定
4. 体现世界观特色
5. 使用{narrative_perspective}视角
6. 字数不得低于3000字
7. 语言自然流畅,避免AI痕迹
**写作风格要求(重要):**
- 让故事自然流淌,写到哪算哪
- 结尾处直接结束情节,不要加总结性段落
- 不要在章节末尾写"这一天/这一夜就这样过去了"之类的总结句
- 不要用"他/她陷入了沉思"作为结尾
- 避免刻意的情感升华或哲理感悟收尾
- 章节结尾可以戛然而止,可以是对话,可以是动作,可以是悬念
- 就像在讲一个故事,讲完了就停,不需要画龙点睛
请直接输出章节正文内容,不要包含章节标题和其他说明文字。"""
# 章节完整创作提示词(带前置章节上下文)
CHAPTER_GENERATION_WITH_CONTEXT = """你是一位专业的小说作家。请根据以下信息创作本章内容:
项目信息:
- 书名:{title}
- 主题:{theme}
- 类型:{genre}
- 叙事视角:{narrative_perspective}
世界观:
- 时间背景:{time_period}
- 地理位置:{location}
- 氛围基调:{atmosphere}
- 世界规则:{rules}
角色信息:
{characters_info}
全书大纲:
{outlines_context}
【已完成的前置章节内容】
{previous_content}
本章信息:
- 章节序号:第{chapter_number}
- 章节标题:{chapter_title}
- 章节大纲:{chapter_outline}
创作要求:
1. **剧情连贯性(最重要)**
- 必须承接前面章节的剧情发展
- 注意角色状态、情节进展、时间线的连续性
- 不能出现与前文矛盾的内容
- 自然过渡,避免突兀的跳跃
2. **情节推进**
- 严格按照本章大纲展开情节
- 推动故事向前发展
- 保持与全书大纲的一致性
3. **角色一致性**
- 符合角色性格设定
- 延续角色在前文中的成长和变化
- 保持角色关系的连贯性
4. **写作风格**
- 使用{narrative_perspective}视角
- 字数不得低于3000字
- 语言自然流畅,避免AI痕迹
- 体现世界观特色
5. **承上启下**
- 开头自然衔接上一章结尾
- 结尾为下一章做好铺垫
**写作风格要求(重要):**
- 让故事自然流淌,写到哪算哪
- 结尾处直接结束情节,不要加总结性段落
- 不要在章节末尾写"这一天/这一夜就这样过去了"之类的总结句
- 不要用"他/她陷入了沉思"作为结尾
- 避免刻意的情感升华或哲理感悟收尾
- 章节结尾可以戛然而止,可以是对话,可以是动作,可以是悬念
- 就像在讲一个故事,讲完了就停,不需要画龙点睛
请直接输出章节正文内容,不要包含章节标题和其他说明文字。"""
# 大纲生成提示词
OUTLINE_GENERATION = """你是一位经验丰富的小说作家和编剧。请根据以下信息生成小说大纲:
类型:{genre}
主题:{theme}
目标字数:{target_words}
其他要求:{requirements}
请生成一个完整的章节大纲框架,包含:
1. 合理的章节数量(根据字数)
2. 每章的标题和内容概要
3. 清晰的故事结构(起承转合)
4. 情节的递进和冲突升级
5. 角色的成长弧线
**重要:你必须只返回纯JSON格式,不要包含任何markdown标记、代码块标记或其他说明文字。**
请严格按照以下JSON格式返回:
{{
"chapters": [
{{
"order": 1,
"title": "章节标题",
"content": "章节内容概要(150-200字)"
}}
]
}}
再次强调:只返回纯JSON对象,不要有```json```这样的标记,不要有任何额外的文字说明。"""
# 单个角色生成提示词
SINGLE_CHARACTER_GENERATION = """你是一位专业的角色设定师。请根据以下信息创建一个立体饱满的小说角色。
{project_context}
{user_input}
请生成一个完整的角色卡片,包含以下所有信息:
1. **基本信息**
- 姓名:如果用户未提供,请生成一个符合世界观的名字
- 年龄:具体数字或年龄段
- 性别:男/女/其他
2. **外貌特征**100-150字):
- 身高体型、面容特征、着装风格
- 要符合角色定位和世界观设定
3. **性格特点**150-200字):
- 核心性格特质(至少3个)
- 优点和缺点
- 特殊习惯或癖好
- 性格要有复杂性和矛盾性
4. **背景故事**200-300字):
- 家庭背景
- 成长经历
- 重要转折事件
- 如何与项目主题关联
- 融入用户提供的背景设定
5. **人际关系**
- 与现有角色的关系(如果有)
- 重要的人际纽带
- 社会地位和人脉
6. **特殊能力/特长**
- 擅长的领域
- 特殊技能或知识
- 符合世界观设定
**你必须只返回纯JSON格式,不要包含任何markdown标记、代码块标记或其他说明文字。**
请严格按照以下JSON格式返回:
{{
"name": "角色姓名",
"age": "年龄",
"gender": "性别",
"appearance": "外貌描述(100-150字)",
"personality": "性格特点(150-200字)",
"background": "背景故事(200-300字)",
"traits": ["特长1", "特长2", "特长3"],
"relationships_text": "人际关系的文字描述(用于显示)",
"relationships": [
{{
"target_character_name": "已存在的角色名称",
"relationship_type": "关系类型(如:师父、朋友、敌人、父亲、母亲等)",
"intimacy_level": 75,
"description": "这段关系的详细描述",
"started_at": "关系开始的故事时间点(可选)"
}}
],
"organization_memberships": [
{{
"organization_name": "已存在的组织名称",
"position": "职位名称",
"rank": 8,
"loyalty": 80,
"joined_at": "加入时间(可选)",
"status": "active"
}}
]
}}
**关系类型参考(请从中选择或自定义):**
- 家族关系:父亲、母亲、兄弟、姐妹、子女、配偶、恋人
- 社交关系:师父、徒弟、朋友、同学、同事、邻居、知己
- 职业关系:上司、下属、合作伙伴
- 敌对关系:敌人、仇人、竞争对手、宿敌
**重要说明:**
1. relationships数组:只包含与上面列出的已存在角色的关系,通过target_character_name匹配
2. organization_memberships数组:只包含与上面列出的已存在组织的关系
3. intimacy_level和loyalty都是0-100的整数
4. 如果没有关系或组织,对应数组为空[]
5. relationships_text是自然语言描述,用于展示给用户看
**角色设定要求:**
- 角色要符合项目的世界观和主题
- 如果是主角,要有明确的成长空间和目标动机
- 如果是反派,要有合理的动机,不能脸谱化
- 配角要有独特性,不能是工具人
- 所有设定要为故事服务
再次强调:只返回纯JSON对象,不要有```json```这样的标记,不要有任何额外的文字说明。"""
@staticmethod
def format_prompt(template: str, **kwargs) -> str:
"""
格式化提示词模板
Args:
template: 提示词模板
**kwargs: 模板参数
Returns:
格式化后的提示词
"""
try:
return template.format(**kwargs)
except KeyError as e:
raise ValueError(f"缺少必需的参数: {e}")
@classmethod
def get_denoising_prompt(cls, original_text: str) -> str:
"""获取AI去味提示词"""
return cls.format_prompt(
cls.AI_DENOISING,
original_text=original_text
)
@classmethod
def get_world_building_prompt(cls, title: str, theme: str, genre: str = "") -> str:
"""获取世界构建提示词"""
return cls.format_prompt(
cls.WORLD_BUILDING,
title=title,
theme=theme,
genre=genre or "通用类型"
)
@classmethod
def get_characters_batch_prompt(cls, count: int, time_period: str, location: str,
atmosphere: str, rules: str, theme: str,
genre: str = "", requirements: str = "") -> str:
"""获取批量角色生成提示词"""
return cls.format_prompt(
cls.CHARACTERS_BATCH_GENERATION,
count=count,
time_period=time_period,
location=location,
atmosphere=atmosphere,
rules=rules,
theme=theme,
genre=genre or "通用类型",
requirements=requirements or "无特殊要求"
)
@classmethod
def get_complete_outline_prompt(cls, title: str, theme: str, genre: str,
chapter_count: int, narrative_perspective: str,
target_words: int, time_period: str, location: str,
atmosphere: str, rules: str, characters_info: str,
requirements: str = "") -> str:
"""获取完整大纲生成提示词"""
return cls.format_prompt(
cls.COMPLETE_OUTLINE_GENERATION,
title=title,
theme=theme,
genre=genre,
chapter_count=chapter_count,
narrative_perspective=narrative_perspective,
target_words=target_words,
time_period=time_period,
location=location,
atmosphere=atmosphere,
rules=rules,
characters_info=characters_info,
requirements=requirements or "无特殊要求"
)
@classmethod
def get_chapter_generation_prompt(cls, title: str, theme: str, genre: str,
narrative_perspective: str, time_period: str,
location: str, atmosphere: str, rules: str,
characters_info: str, outlines_context: str,
chapter_number: int, chapter_title: str,
chapter_outline: str) -> str:
"""获取章节完整创作提示词"""
return cls.format_prompt(
cls.CHAPTER_GENERATION,
title=title,
theme=theme,
genre=genre,
narrative_perspective=narrative_perspective,
time_period=time_period,
location=location,
atmosphere=atmosphere,
rules=rules,
characters_info=characters_info,
outlines_context=outlines_context,
chapter_number=chapter_number,
chapter_title=chapter_title,
chapter_outline=chapter_outline
)
@classmethod
def get_chapter_generation_with_context_prompt(cls, title: str, theme: str, genre: str,
narrative_perspective: str, time_period: str,
location: str, atmosphere: str, rules: str,
characters_info: str, outlines_context: str,
previous_content: str, chapter_number: int,
chapter_title: str, chapter_outline: str) -> str:
"""获取章节完整创作提示词(带前置章节上下文)"""
return cls.format_prompt(
cls.CHAPTER_GENERATION_WITH_CONTEXT,
title=title,
theme=theme,
genre=genre,
narrative_perspective=narrative_perspective,
time_period=time_period,
location=location,
atmosphere=atmosphere,
rules=rules,
characters_info=characters_info,
outlines_context=outlines_context,
previous_content=previous_content,
chapter_number=chapter_number,
chapter_title=chapter_title,
chapter_outline=chapter_outline
)
@classmethod
def get_outline_prompt(cls, genre: str, theme: str, target_words: int,
requirements: str = "") -> str:
"""获取大纲生成提示词"""
return cls.format_prompt(
cls.OUTLINE_GENERATION,
genre=genre,
theme=theme,
target_words=target_words,
requirements=requirements or "无特殊要求"
)
@classmethod
def get_outline_continue_prompt(cls, title: str, theme: str, genre: str,
narrative_perspective: str, chapter_count: int,
time_period: str, location: str, atmosphere: str,
rules: str, characters_info: str,
current_chapter_count: int, all_chapters_brief: str,
recent_plot: str, plot_stage_instruction: str,
start_chapter: int, story_direction: str,
requirements: str = "") -> str:
"""获取大纲续写提示词"""
end_chapter = start_chapter + chapter_count - 1
return cls.format_prompt(
cls.OUTLINE_CONTINUE_GENERATION,
title=title,
theme=theme,
genre=genre,
narrative_perspective=narrative_perspective,
chapter_count=chapter_count,
time_period=time_period,
location=location,
atmosphere=atmosphere,
rules=rules,
characters_info=characters_info,
current_chapter_count=current_chapter_count,
all_chapters_brief=all_chapters_brief,
recent_plot=recent_plot,
plot_stage_instruction=plot_stage_instruction,
start_chapter=start_chapter,
end_chapter=end_chapter,
story_direction=story_direction,
requirements=requirements or "无特殊要求"
)
@classmethod
def get_single_character_prompt(cls, project_context: str, user_input: str) -> str:
"""获取单个角色生成提示词"""
return cls.format_prompt(
cls.SINGLE_CHARACTER_GENERATION,
project_context=project_context,
user_input=user_input
)
# 创建全局提示词服务实例
prompt_service = PromptService()
+294
View File
@@ -0,0 +1,294 @@
"""
用户管理模块 - 支持 LinuxDO OAuth2
"""
import json
import os
import asyncio
from datetime import datetime
from typing import Optional, Dict, List
from pydantic import BaseModel
from app.config import settings, DATA_DIR
class User(BaseModel):
"""用户模型"""
user_id: str # 格式: linuxdo_{linuxdo_id}
username: str
display_name: str
avatar_url: Optional[str] = None
trust_level: int = 0 # 仅用于显示
is_admin: bool = False # 手动设置的管理员权限
linuxdo_id: str # LinuxDO 用户 ID
created_at: str
last_login: str
class UserManager:
"""用户管理器 - 线程安全版本"""
USERS_FILE = str(DATA_DIR / "users.json")
ADMINS_FILE = str(DATA_DIR / "admins.json")
def __init__(self):
"""初始化用户管理器"""
# DATA_DIR 已在 config.py 中创建,无需重复创建
# 添加文件锁保护并发读写
self._users_lock = asyncio.Lock()
self._admins_lock = asyncio.Lock()
self._ensure_files_exist()
def _ensure_files_exist(self):
"""确保必要的文件存在"""
if not os.path.exists(self.USERS_FILE):
with open(self.USERS_FILE, "w", encoding="utf-8") as f:
json.dump({}, f, ensure_ascii=False, indent=2)
if not os.path.exists(self.ADMINS_FILE):
with open(self.ADMINS_FILE, "w", encoding="utf-8") as f:
json.dump({"admins": []}, f, ensure_ascii=False, indent=2)
def _load_users_unsafe(self) -> Dict[str, dict]:
"""加载用户数据(不加锁,内部使用)"""
try:
with open(self.USERS_FILE, "r", encoding="utf-8") as f:
return json.load(f)
except Exception as e:
print(f"加载用户数据失败: {e}")
return {}
def _save_users_unsafe(self, users: Dict[str, dict]):
"""保存用户数据(不加锁,内部使用)"""
try:
with open(self.USERS_FILE, "w", encoding="utf-8") as f:
json.dump(users, f, ensure_ascii=False, indent=2)
except Exception as e:
print(f"保存用户数据失败: {e}")
async def _load_users(self) -> Dict[str, dict]:
"""加载用户数据(加锁)"""
async with self._users_lock:
return self._load_users_unsafe()
async def _save_users(self, users: Dict[str, dict]):
"""保存用户数据(加锁)"""
async with self._users_lock:
self._save_users_unsafe(users)
def _load_admin_list_unsafe(self) -> List[str]:
"""加载管理员列表(不加锁,内部使用)"""
try:
with open(self.ADMINS_FILE, "r", encoding="utf-8") as f:
data = json.load(f)
return data.get("admins", [])
except Exception as e:
print(f"加载管理员列表失败: {e}")
return []
def _save_admin_list_unsafe(self, admin_list: List[str]):
"""保存管理员列表(不加锁,内部使用)"""
try:
with open(self.ADMINS_FILE, "w", encoding="utf-8") as f:
json.dump({"admins": admin_list}, f, ensure_ascii=False, indent=2)
except Exception as e:
print(f"保存管理员列表失败: {e}")
async def _load_admin_list(self) -> List[str]:
"""加载管理员列表(加锁)"""
async with self._admins_lock:
return self._load_admin_list_unsafe()
async def _save_admin_list(self, admin_list: List[str]):
"""保存管理员列表(加锁)"""
async with self._admins_lock:
self._save_admin_list_unsafe(admin_list)
async def create_or_update_from_linuxdo(
self,
linuxdo_id: str,
username: str,
display_name: str,
avatar_url: Optional[str],
trust_level: int
) -> User:
"""
从 LinuxDO 用户信息创建或更新用户(线程安全)
Args:
linuxdo_id: LinuxDO 用户 ID(本地用户时为 local_xxx 格式)
username: 用户名
display_name: 显示名称
avatar_url: 头像 URL
trust_level: 信任等级 (仅用于显示)
Returns:
用户对象
"""
# 如果已经是 local_ 开头,直接使用;否则添加 linuxdo_ 前缀
if linuxdo_id.startswith("local_"):
user_id = linuxdo_id
else:
user_id = f"linuxdo_{linuxdo_id}"
# 使用锁保护整个读-改-写操作
async with self._users_lock:
async with self._admins_lock:
users = self._load_users_unsafe()
admin_list = self._load_admin_list_unsafe()
now = datetime.now().isoformat()
# 检查是否为初始管理员
initial_admin_id = settings.INITIAL_ADMIN_LINUXDO_ID
is_initial_admin = (initial_admin_id and linuxdo_id == initial_admin_id)
# 检查是否为本地用户(所有 local_ 开头的用户默认为管理员)
is_local_user = user_id.startswith("local_")
if user_id in users:
# 更新现有用户
user_data = users[user_id]
user_data["username"] = username
user_data["display_name"] = display_name
user_data["avatar_url"] = avatar_url
user_data["trust_level"] = trust_level
user_data["last_login"] = now
# 如果是初始管理员或本地用户且还不在管理员列表中,添加进去
if (is_initial_admin or is_local_user) and user_id not in admin_list:
admin_list.append(user_id)
self._save_admin_list_unsafe(admin_list)
user_data["is_admin"] = True
else:
# 从管理员列表同步 is_admin 状态
user_data["is_admin"] = user_id in admin_list
else:
# 创建新用户(本地用户默认为管理员)
is_admin = is_initial_admin or is_local_user
if is_admin and user_id not in admin_list:
admin_list.append(user_id)
self._save_admin_list_unsafe(admin_list)
user_data = {
"user_id": user_id,
"username": username,
"display_name": display_name,
"avatar_url": avatar_url,
"trust_level": trust_level,
"is_admin": is_admin,
"linuxdo_id": linuxdo_id,
"created_at": now,
"last_login": now
}
users[user_id] = user_data
self._save_users_unsafe(users)
return User(**user_data)
async def get_user(self, user_id: str) -> Optional[User]:
"""获取用户(线程安全)"""
users = await self._load_users()
user_data = users.get(user_id)
if user_data:
# 同步管理员状态
admin_list = await self._load_admin_list()
user_data["is_admin"] = user_id in admin_list
return User(**user_data)
return None
async def get_all_users(self) -> List[User]:
"""获取所有用户(线程安全)"""
users = await self._load_users()
admin_list = await self._load_admin_list()
user_list = []
for user_data in users.values():
# 同步管理员状态
user_data["is_admin"] = user_data["user_id"] in admin_list
user_list.append(User(**user_data))
return user_list
async def set_admin(self, user_id: str, is_admin: bool) -> bool:
"""
设置用户的管理员权限(线程安全)
Args:
user_id: 用户 ID
is_admin: 是否为管理员
Returns:
是否成功
"""
# 使用锁保护整个读-改-写操作
async with self._users_lock:
async with self._admins_lock:
users = self._load_users_unsafe()
if user_id not in users:
return False
admin_list = self._load_admin_list_unsafe()
if is_admin:
# 授予管理员权限
if user_id not in admin_list:
admin_list.append(user_id)
self._save_admin_list_unsafe(admin_list)
else:
# 撤销管理员权限
if user_id in admin_list:
# 确保至少保留一个管理员
if len(admin_list) <= 1:
return False
admin_list.remove(user_id)
self._save_admin_list_unsafe(admin_list)
# 更新用户数据中的 is_admin 字段
users[user_id]["is_admin"] = is_admin
self._save_users_unsafe(users)
return True
async def delete_user(self, user_id: str) -> bool:
"""
删除用户(线程安全)
Args:
user_id: 用户 ID
Returns:
是否成功
"""
# 使用锁保护整个读-改-写操作
async with self._users_lock:
async with self._admins_lock:
users = self._load_users_unsafe()
if user_id not in users:
return False
# 不能删除管理员
admin_list = self._load_admin_list_unsafe()
if user_id in admin_list:
return False
# 删除用户数据
del users[user_id]
self._save_users_unsafe(users)
# 删除用户数据库文件(在锁外执行,避免阻塞)
db_file = str(DATA_DIR / f"ai_story_user_{user_id}.db")
if os.path.exists(db_file):
try:
os.remove(db_file)
except Exception as e:
print(f"删除用户数据库文件失败: {e}")
return True
async def is_admin(self, user_id: str) -> bool:
"""检查用户是否为管理员(线程安全)"""
admin_list = await self._load_admin_list()
return user_id in admin_list
# 全局用户管理器实例
user_manager = UserManager()
+347
View File
@@ -0,0 +1,347 @@
"""数据一致性辅助函数"""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from typing import Optional, Tuple, List
from app.models.character import Character
from app.models.relationship import Organization, OrganizationMember, CharacterRelationship
from app.logger import get_logger
logger = get_logger(__name__)
async def ensure_organization_record(
character: Character,
db: AsyncSession,
power_level: int = 50,
location: Optional[str] = None,
motto: Optional[str] = None
) -> Optional[Organization]:
"""
确保组织角色拥有对应的Organization记录
Args:
character: Character对象(必须是is_organization=True
db: 数据库会话
power_level: 势力等级(默认50
location: 所在地
motto: 宗旨/口号
Returns:
Organization对象,如果character不是组织则返回None
"""
if not character.is_organization:
logger.debug(f"角色 {character.name} 不是组织,跳过Organization记录创建")
return None
# 检查是否已存在
result = await db.execute(
select(Organization).where(Organization.character_id == character.id)
)
org = result.scalar_one_or_none()
if not org:
# 创建新的Organization记录
org = Organization(
character_id=character.id,
project_id=character.project_id,
member_count=0,
power_level=power_level,
location=location,
motto=motto
)
db.add(org)
await db.flush()
await db.refresh(org)
logger.info(f"✅ 自动创建组织详情:{character.name} (Org ID: {org.id})")
else:
logger.debug(f"组织详情已存在:{character.name} (Org ID: {org.id})")
return org
async def sync_organization_member_count(
organization: Organization,
db: AsyncSession
) -> int:
"""
同步组织的成员计数,从实际成员记录计算
Args:
organization: Organization对象
db: 数据库会话
Returns:
实际成员数量
"""
result = await db.execute(
select(OrganizationMember).where(
OrganizationMember.organization_id == organization.id,
OrganizationMember.status == "active"
)
)
members = result.scalars().all()
actual_count = len(members)
if organization.member_count != actual_count:
logger.warning(
f"组织 {organization.id} 成员计数不一致:"
f"记录值={organization.member_count}, 实际值={actual_count},已修正"
)
organization.member_count = actual_count
await db.flush()
return actual_count
async def fix_missing_organization_records(
project_id: str,
db: AsyncSession
) -> Tuple[int, int]:
"""
修复项目中缺失的Organization记录
为所有is_organization=True但没有Organization记录的Character创建记录
Args:
project_id: 项目ID
db: 数据库会话
Returns:
(修复数量, 检查总数)
"""
# 查找所有组织角色
result = await db.execute(
select(Character).where(
Character.project_id == project_id,
Character.is_organization == True
)
)
org_characters = result.scalars().all()
fixed_count = 0
for char in org_characters:
org = await ensure_organization_record(char, db)
if org and org.id: # 新创建的才计数
# 检查是否是新创建的(通过查询历史)
result = await db.execute(
select(Organization).where(Organization.character_id == char.id)
)
if result.scalar_one_or_none():
fixed_count += 1
await db.commit()
logger.info(f"📊 修复统计 - 检查了 {len(org_characters)} 个组织,修复了 {fixed_count} 个缺失的Organization记录")
return fixed_count, len(org_characters)
async def fix_organization_member_counts(
project_id: str,
db: AsyncSession
) -> Tuple[int, int]:
"""
修复项目中所有组织的成员计数
Args:
project_id: 项目ID
db: 数据库会话
Returns:
(修复数量, 检查总数)
"""
# 查找所有组织
result = await db.execute(
select(Organization).where(Organization.project_id == project_id)
)
organizations = result.scalars().all()
fixed_count = 0
for org in organizations:
old_count = org.member_count
actual_count = await sync_organization_member_count(org, db)
if old_count != actual_count:
fixed_count += 1
await db.commit()
logger.info(f"📊 修复统计 - 检查了 {len(organizations)} 个组织,修复了 {fixed_count} 个计数错误")
return fixed_count, len(organizations)
async def validate_relationships(
project_id: str,
db: AsyncSession
) -> List[dict]:
"""
验证项目中的关系数据完整性
检查所有关系中的character_from_id和character_to_id是否都指向存在的角色
Args:
project_id: 项目ID
db: 数据库会话
Returns:
问题列表,每个问题包含 {issue_type, relationship_id, details}
"""
issues = []
# 获取所有关系
result = await db.execute(
select(CharacterRelationship).where(CharacterRelationship.project_id == project_id)
)
relationships = result.scalars().all()
for rel in relationships:
# 检查from角色
from_char = await db.execute(
select(Character).where(Character.id == rel.character_from_id)
)
if not from_char.scalar_one_or_none():
issues.append({
"issue_type": "missing_from_character",
"relationship_id": rel.id,
"details": f"关系 {rel.id} 的源角色 {rel.character_from_id} 不存在"
})
# 检查to角色
to_char = await db.execute(
select(Character).where(Character.id == rel.character_to_id)
)
if not to_char.scalar_one_or_none():
issues.append({
"issue_type": "missing_to_character",
"relationship_id": rel.id,
"details": f"关系 {rel.id} 的目标角色 {rel.character_to_id} 不存在"
})
if issues:
logger.warning(f"⚠️ 发现 {len(issues)} 个关系数据问题")
for issue in issues:
logger.warning(f" - {issue['details']}")
else:
logger.info(f"✅ 所有 {len(relationships)} 条关系数据完整")
return issues
async def validate_organization_members(
project_id: str,
db: AsyncSession
) -> List[dict]:
"""
验证项目中的组织成员数据完整性
检查所有成员关系中的organization_id和character_id是否都有效
Args:
project_id: 项目ID
db: 数据库会话
Returns:
问题列表
"""
issues = []
# 获取所有成员关系
result = await db.execute(
select(OrganizationMember).where(
OrganizationMember.organization_id.in_(
select(Organization.id).where(Organization.project_id == project_id)
)
)
)
members = result.scalars().all()
for member in members:
# 检查组织
org = await db.execute(
select(Organization).where(Organization.id == member.organization_id)
)
if not org.scalar_one_or_none():
issues.append({
"issue_type": "missing_organization",
"member_id": member.id,
"details": f"成员 {member.id} 的组织 {member.organization_id} 不存在"
})
# 检查角色
char = await db.execute(
select(Character).where(Character.id == member.character_id)
)
if not char.scalar_one_or_none():
issues.append({
"issue_type": "missing_character",
"member_id": member.id,
"details": f"成员 {member.id} 的角色 {member.character_id} 不存在"
})
if issues:
logger.warning(f"⚠️ 发现 {len(issues)} 个组织成员数据问题")
for issue in issues:
logger.warning(f" - {issue['details']}")
else:
logger.info(f"✅ 所有 {len(members)} 条组织成员数据完整")
return issues
async def run_full_data_consistency_check(
project_id: str,
db: AsyncSession,
auto_fix: bool = True
) -> dict:
"""
对项目运行完整的数据一致性检查和修复
Args:
project_id: 项目ID
db: 数据库会话
auto_fix: 是否自动修复问题(默认True)
Returns:
检查报告字典
"""
logger.info(f"🔍 开始数据一致性检查 - 项目 {project_id}")
report = {
"project_id": project_id,
"checks": {}
}
# 1. 检查并修复缺失的Organization记录
if auto_fix:
fixed, total = await fix_missing_organization_records(project_id, db)
report["checks"]["organization_records"] = {
"checked": total,
"fixed": fixed,
"status": "ok" if fixed == 0 else "fixed"
}
# 2. 检查并修复成员计数
if auto_fix:
fixed, total = await fix_organization_member_counts(project_id, db)
report["checks"]["member_counts"] = {
"checked": total,
"fixed": fixed,
"status": "ok" if fixed == 0 else "fixed"
}
# 3. 验证关系数据
rel_issues = await validate_relationships(project_id, db)
report["checks"]["relationships"] = {
"issues_found": len(rel_issues),
"issues": rel_issues,
"status": "ok" if len(rel_issues) == 0 else "warning"
}
# 4. 验证组织成员数据
member_issues = await validate_organization_members(project_id, db)
report["checks"]["organization_members"] = {
"issues_found": len(member_issues),
"issues": member_issues,
"status": "ok" if len(member_issues) == 0 else "warning"
}
logger.info(f"✅ 数据一致性检查完成")
return report
+169
View File
@@ -0,0 +1,169 @@
"""Server-Sent Events (SSE) 响应工具类"""
import json
import asyncio
from typing import AsyncGenerator, Dict, Any, Optional
from fastapi.responses import StreamingResponse
from app.logger import get_logger
logger = get_logger(__name__)
class SSEResponse:
"""SSE响应构建器"""
@staticmethod
def format_sse(data: Dict[str, Any], event: Optional[str] = None) -> str:
"""
格式化SSE消息
Args:
data: 要发送的数据字典
event: 事件类型(可选)
Returns:
格式化后的SSE消息字符串
"""
message = ""
if event:
message += f"event: {event}\n"
message += f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
return message
@staticmethod
async def send_progress(
message: str,
progress: int,
status: str = "processing"
) -> str:
"""
发送进度消息
Args:
message: 进度消息
progress: 进度百分比(0-100)
status: 状态(processing/success/error)
"""
return SSEResponse.format_sse({
"type": "progress",
"message": message,
"progress": progress,
"status": status
})
@staticmethod
async def send_chunk(content: str) -> str:
"""
发送内容块(用于流式输出AI生成内容)
Args:
content: 内容块
"""
return SSEResponse.format_sse({
"type": "chunk",
"content": content
})
@staticmethod
async def send_result(data: Dict[str, Any]) -> str:
"""
发送最终结果
Args:
data: 结果数据
"""
return SSEResponse.format_sse({
"type": "result",
"data": data
})
@staticmethod
async def send_error(error: str, code: int = 500) -> str:
"""
发送错误消息
Args:
error: 错误描述
code: 错误码
"""
return SSEResponse.format_sse({
"type": "error",
"error": error,
"code": code
})
@staticmethod
async def send_done() -> str:
"""发送完成消息"""
return SSEResponse.format_sse({
"type": "done"
})
@staticmethod
async def send_heartbeat() -> str:
"""发送心跳消息(保持连接活跃)"""
return ": heartbeat\n\n"
async def create_sse_generator(
async_gen: AsyncGenerator[str, None],
show_progress: bool = True
) -> AsyncGenerator[str, None]:
"""
创建SSE生成器包装器
Args:
async_gen: 异步生成器
show_progress: 是否显示进度
Yields:
格式化的SSE消息
"""
try:
if show_progress:
yield await SSEResponse.send_progress("开始生成...", 0)
# 累积内容用于进度计算
accumulated_content = ""
chunk_count = 0
async for chunk in async_gen:
chunk_count += 1
accumulated_content += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 每10个块发送一次心跳
if chunk_count % 10 == 0:
yield await SSEResponse.send_heartbeat()
if show_progress:
yield await SSEResponse.send_progress("生成完成", 100, "success")
# 发送完成信号
yield await SSEResponse.send_done()
except Exception as e:
logger.error(f"SSE生成器错误: {str(e)}")
yield await SSEResponse.send_error(str(e))
def create_sse_response(generator: AsyncGenerator[str, None]) -> StreamingResponse:
"""
创建SSE StreamingResponse
Args:
generator: SSE消息生成器
Returns:
StreamingResponse对象
"""
return StreamingResponse(
generator,
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # 禁用nginx缓冲
}
)
+20
View File
@@ -0,0 +1,20 @@
# Web框架
fastapi==0.109.0
uvicorn[standard]==0.27.0
python-multipart==0.0.6
# 数据库
sqlalchemy==2.0.25
aiosqlite==0.19.0
# 数据验证
pydantic==2.5.3
pydantic-settings==2.1.0
# AI服务
openai==1.10.0
anthropic==0.18.0
# 工具库
httpx==0.26.0
python-dotenv==1.0.0