init
This commit is contained in:
@@ -0,0 +1,2 @@
|
||||
"""AI Story Creator - 后端应用包"""
|
||||
__version__ = "1.0.0"
|
||||
@@ -0,0 +1 @@
|
||||
"""API路由模块"""
|
||||
@@ -0,0 +1,230 @@
|
||||
"""
|
||||
认证 API - LinuxDO OAuth2 登录 + 本地账户登录
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, Response, Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
import hashlib
|
||||
from app.services.oauth_service import LinuxDOOAuthService
|
||||
from app.user_manager import user_manager
|
||||
from app.database import init_db
|
||||
from app.logger import get_logger
|
||||
from app.config import settings
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["认证"])
|
||||
|
||||
# OAuth2 服务实例
|
||||
oauth_service = LinuxDOOAuthService()
|
||||
|
||||
# State 临时存储(生产环境应使用 Redis)
|
||||
_state_storage = {}
|
||||
|
||||
|
||||
class AuthUrlResponse(BaseModel):
|
||||
auth_url: str
|
||||
state: str
|
||||
|
||||
|
||||
class LocalLoginRequest(BaseModel):
|
||||
"""本地登录请求"""
|
||||
username: str
|
||||
password: str
|
||||
|
||||
|
||||
class LocalLoginResponse(BaseModel):
|
||||
"""本地登录响应"""
|
||||
success: bool
|
||||
message: str
|
||||
user: Optional[dict] = None
|
||||
|
||||
|
||||
@router.get("/config")
|
||||
async def get_auth_config():
|
||||
"""获取认证配置信息"""
|
||||
return {
|
||||
"local_auth_enabled": settings.LOCAL_AUTH_ENABLED,
|
||||
"linuxdo_enabled": bool(settings.LINUXDO_CLIENT_ID and settings.LINUXDO_CLIENT_SECRET)
|
||||
}
|
||||
|
||||
|
||||
@router.post("/local/login", response_model=LocalLoginResponse)
|
||||
async def local_login(request: LocalLoginRequest, response: Response):
|
||||
"""本地账户登录"""
|
||||
# 检查是否启用本地登录
|
||||
if not settings.LOCAL_AUTH_ENABLED:
|
||||
raise HTTPException(status_code=403, detail="本地账户登录未启用")
|
||||
|
||||
# 检查是否配置了本地账户
|
||||
if not settings.LOCAL_AUTH_USERNAME or not settings.LOCAL_AUTH_PASSWORD:
|
||||
raise HTTPException(status_code=500, detail="本地账户未配置")
|
||||
|
||||
# 验证用户名和密码
|
||||
if request.username != settings.LOCAL_AUTH_USERNAME or request.password != settings.LOCAL_AUTH_PASSWORD:
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
# 生成本地用户ID(使用用户名的hash)
|
||||
user_id = f"local_{hashlib.md5(request.username.encode()).hexdigest()[:16]}"
|
||||
|
||||
# 创建或更新本地用户
|
||||
user = await user_manager.create_or_update_from_linuxdo(
|
||||
linuxdo_id=user_id,
|
||||
username=request.username,
|
||||
display_name=settings.LOCAL_AUTH_DISPLAY_NAME,
|
||||
avatar_url=None,
|
||||
trust_level=9 # 本地用户给予高信任级别
|
||||
)
|
||||
|
||||
# 初始化用户数据库
|
||||
try:
|
||||
await init_db(user.user_id)
|
||||
logger.info(f"本地用户 {user.user_id} 数据库初始化成功")
|
||||
except Exception as e:
|
||||
logger.error(f"本地用户 {user.user_id} 数据库初始化失败: {e}")
|
||||
|
||||
# 设置 Cookie(7天有效)
|
||||
response.set_cookie(
|
||||
key="user_id",
|
||||
value=user.user_id,
|
||||
max_age=7 * 24 * 60 * 60, # 7天
|
||||
httponly=True,
|
||||
samesite="lax"
|
||||
)
|
||||
|
||||
return LocalLoginResponse(
|
||||
success=True,
|
||||
message="登录成功",
|
||||
user=user.dict()
|
||||
)
|
||||
|
||||
|
||||
@router.get("/linuxdo/url", response_model=AuthUrlResponse)
|
||||
async def get_linuxdo_auth_url():
|
||||
"""获取 LinuxDO 授权 URL"""
|
||||
state = oauth_service.generate_state()
|
||||
auth_url = oauth_service.get_authorization_url(state)
|
||||
|
||||
# 临时存储 state(5分钟有效)
|
||||
_state_storage[state] = True
|
||||
|
||||
return AuthUrlResponse(auth_url=auth_url, state=state)
|
||||
|
||||
|
||||
async def _handle_callback(
|
||||
code: Optional[str] = None,
|
||||
state: Optional[str] = None,
|
||||
error: Optional[str] = None,
|
||||
response: Response = None
|
||||
):
|
||||
"""
|
||||
LinuxDO OAuth2 回调处理
|
||||
|
||||
成功后重定向到前端首页,并设置 user_id Cookie
|
||||
"""
|
||||
# 检查是否有错误
|
||||
if error:
|
||||
raise HTTPException(status_code=400, detail=f"授权失败: {error}")
|
||||
|
||||
# 检查必需参数
|
||||
if not code or not state:
|
||||
raise HTTPException(status_code=400, detail="缺少 code 或 state 参数")
|
||||
|
||||
# 验证 state(防止 CSRF)
|
||||
if state not in _state_storage:
|
||||
raise HTTPException(status_code=400, detail="无效的 state 参数")
|
||||
|
||||
# 删除已使用的 state
|
||||
del _state_storage[state]
|
||||
|
||||
# 1. 使用 code 获取 access_token
|
||||
token_data = await oauth_service.get_access_token(code)
|
||||
if not token_data or "access_token" not in token_data:
|
||||
raise HTTPException(status_code=400, detail="获取访问令牌失败")
|
||||
|
||||
access_token = token_data["access_token"]
|
||||
|
||||
# 2. 使用 access_token 获取用户信息
|
||||
user_info = await oauth_service.get_user_info(access_token)
|
||||
if not user_info:
|
||||
raise HTTPException(status_code=400, detail="获取用户信息失败")
|
||||
|
||||
# 3. 创建或更新用户
|
||||
linuxdo_id = str(user_info.get("id"))
|
||||
username = user_info.get("username", "")
|
||||
display_name = user_info.get("name", username)
|
||||
avatar_url = user_info.get("avatar_url")
|
||||
trust_level = user_info.get("trust_level", 0)
|
||||
|
||||
user = await user_manager.create_or_update_from_linuxdo(
|
||||
linuxdo_id=linuxdo_id,
|
||||
username=username,
|
||||
display_name=display_name,
|
||||
avatar_url=avatar_url,
|
||||
trust_level=trust_level
|
||||
)
|
||||
|
||||
# 3.5. 初始化用户数据库(如果是新用户)
|
||||
try:
|
||||
await init_db(user.user_id)
|
||||
logger.info(f"用户 {user.user_id} 数据库初始化成功")
|
||||
except Exception as e:
|
||||
logger.error(f"用户 {user.user_id} 数据库初始化失败: {e}")
|
||||
# 继续执行,不影响登录流程(可能是已存在的用户)
|
||||
|
||||
# 4. 设置 Cookie 并重定向到前端回调页面
|
||||
# 使用配置的前端URL,支持不同的部署环境
|
||||
frontend_url = settings.FRONTEND_URL.rstrip('/')
|
||||
redirect_url = f"{frontend_url}/auth/callback"
|
||||
logger.info(f"OAuth回调成功,重定向到前端: {redirect_url}")
|
||||
redirect_response = RedirectResponse(url=redirect_url)
|
||||
|
||||
# 设置 httponly Cookie(7天有效)
|
||||
redirect_response.set_cookie(
|
||||
key="user_id",
|
||||
value=user.user_id,
|
||||
max_age=7 * 24 * 60 * 60, # 7天
|
||||
httponly=True,
|
||||
samesite="lax"
|
||||
)
|
||||
|
||||
return redirect_response
|
||||
|
||||
|
||||
@router.get("/linuxdo/callback")
|
||||
async def linuxdo_callback(
|
||||
code: Optional[str] = None,
|
||||
state: Optional[str] = None,
|
||||
error: Optional[str] = None,
|
||||
response: Response = None
|
||||
):
|
||||
"""LinuxDO OAuth2 回调处理(标准路径)"""
|
||||
return await _handle_callback(code, state, error, response)
|
||||
|
||||
|
||||
@router.get("/callback")
|
||||
async def callback_alias(
|
||||
code: Optional[str] = None,
|
||||
state: Optional[str] = None,
|
||||
error: Optional[str] = None,
|
||||
response: Response = None
|
||||
):
|
||||
"""LinuxDO OAuth2 回调处理(兼容路径)"""
|
||||
return await _handle_callback(code, state, error, response)
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
async def logout(response: Response):
|
||||
"""退出登录"""
|
||||
response.delete_cookie("user_id")
|
||||
return {"message": "退出登录成功"}
|
||||
|
||||
|
||||
@router.get("/user")
|
||||
async def get_current_user(request: Request):
|
||||
"""获取当前登录用户信息"""
|
||||
if not hasattr(request.state, "user") or not request.state.user:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
return request.state.user.dict()
|
||||
@@ -0,0 +1,655 @@
|
||||
"""章节管理API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
import json
|
||||
import asyncio
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.chapter import Chapter
|
||||
from app.models.project import Project
|
||||
from app.models.outline import Outline
|
||||
from app.models.character import Character
|
||||
from app.models.generation_history import GenerationHistory
|
||||
from app.schemas.chapter import (
|
||||
ChapterCreate,
|
||||
ChapterUpdate,
|
||||
ChapterResponse,
|
||||
ChapterListResponse
|
||||
)
|
||||
from app.services.ai_service import ai_service
|
||||
from app.services.prompt_service import prompt_service
|
||||
from app.logger import get_logger
|
||||
|
||||
router = APIRouter(prefix="/chapters", tags=["章节管理"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@router.post("", response_model=ChapterResponse, summary="创建章节")
|
||||
async def create_chapter(
|
||||
chapter: ChapterCreate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""创建新的章节"""
|
||||
# 验证项目是否存在
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == chapter.project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
# 计算字数
|
||||
word_count = len(chapter.content)
|
||||
|
||||
db_chapter = Chapter(
|
||||
**chapter.model_dump(),
|
||||
word_count=word_count
|
||||
)
|
||||
db.add(db_chapter)
|
||||
|
||||
# 更新项目的当前字数
|
||||
project.current_words = project.current_words + word_count
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(db_chapter)
|
||||
return db_chapter
|
||||
|
||||
|
||||
@router.get("/project/{project_id}", response_model=ChapterListResponse, summary="获取项目的所有章节")
|
||||
async def get_project_chapters(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取指定项目的所有章节(路径参数版本)"""
|
||||
# 获取总数
|
||||
count_result = await db.execute(
|
||||
select(func.count(Chapter.id)).where(Chapter.project_id == project_id)
|
||||
)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# 获取章节列表
|
||||
result = await db.execute(
|
||||
select(Chapter)
|
||||
.where(Chapter.project_id == project_id)
|
||||
.order_by(Chapter.chapter_number)
|
||||
)
|
||||
chapters = result.scalars().all()
|
||||
|
||||
return ChapterListResponse(total=total, items=chapters)
|
||||
|
||||
|
||||
@router.get("/{chapter_id}", response_model=ChapterResponse, summary="获取章节详情")
|
||||
async def get_chapter(
|
||||
chapter_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""根据ID获取章节详情"""
|
||||
result = await db.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
chapter = result.scalar_one_or_none()
|
||||
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
return chapter
|
||||
|
||||
|
||||
@router.put("/{chapter_id}", response_model=ChapterResponse, summary="更新章节")
|
||||
async def update_chapter(
|
||||
chapter_id: str,
|
||||
chapter_update: ChapterUpdate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""更新章节信息"""
|
||||
result = await db.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
chapter = result.scalar_one_or_none()
|
||||
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
# 记录旧字数
|
||||
old_word_count = chapter.word_count or 0
|
||||
|
||||
# 更新字段
|
||||
update_data = chapter_update.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(chapter, field, value)
|
||||
|
||||
# 如果内容更新了,重新计算字数
|
||||
if "content" in update_data and chapter.content:
|
||||
new_word_count = len(chapter.content)
|
||||
chapter.word_count = new_word_count
|
||||
|
||||
# 更新项目字数
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == chapter.project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if project:
|
||||
project.current_words = project.current_words - old_word_count + new_word_count
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(chapter)
|
||||
return chapter
|
||||
|
||||
|
||||
@router.delete("/{chapter_id}", summary="删除章节")
|
||||
async def delete_chapter(
|
||||
chapter_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""删除章节"""
|
||||
result = await db.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
chapter = result.scalar_one_or_none()
|
||||
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
# 更新项目字数
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == chapter.project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if project:
|
||||
project.current_words = max(0, project.current_words - chapter.word_count)
|
||||
|
||||
await db.delete(chapter)
|
||||
await db.commit()
|
||||
|
||||
return {"message": "章节删除成功"}
|
||||
|
||||
|
||||
async def check_prerequisites(db: AsyncSession, chapter: Chapter) -> tuple[bool, str, list[Chapter]]:
|
||||
"""
|
||||
检查章节前置条件
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
chapter: 当前章节
|
||||
|
||||
Returns:
|
||||
(可否生成, 错误信息, 前置章节列表)
|
||||
"""
|
||||
# 如果是第一章,无需检查前置
|
||||
if chapter.chapter_number == 1:
|
||||
return True, "", []
|
||||
|
||||
# 查询所有前置章节(序号小于当前章节的)
|
||||
result = await db.execute(
|
||||
select(Chapter)
|
||||
.where(Chapter.project_id == chapter.project_id)
|
||||
.where(Chapter.chapter_number < chapter.chapter_number)
|
||||
.order_by(Chapter.chapter_number)
|
||||
)
|
||||
previous_chapters = result.scalars().all()
|
||||
|
||||
# 检查是否所有前置章节都有内容
|
||||
incomplete_chapters = [
|
||||
ch for ch in previous_chapters
|
||||
if not ch.content or ch.content.strip() == ""
|
||||
]
|
||||
|
||||
if incomplete_chapters:
|
||||
missing_numbers = [str(ch.chapter_number) for ch in incomplete_chapters]
|
||||
error_msg = f"需要先完成前置章节:第 {', '.join(missing_numbers)} 章"
|
||||
return False, error_msg, previous_chapters
|
||||
|
||||
return True, "", previous_chapters
|
||||
|
||||
|
||||
@router.get("/{chapter_id}/can-generate", summary="检查章节是否可以生成")
|
||||
async def check_can_generate(
|
||||
chapter_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
检查章节是否满足生成条件
|
||||
返回可生成状态和前置章节信息
|
||||
"""
|
||||
# 获取章节
|
||||
result = await db.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
chapter = result.scalar_one_or_none()
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
# 检查前置条件
|
||||
can_generate, error_msg, previous_chapters = await check_prerequisites(db, chapter)
|
||||
|
||||
# 构建前置章节信息
|
||||
previous_info = [
|
||||
{
|
||||
"id": ch.id,
|
||||
"chapter_number": ch.chapter_number,
|
||||
"title": ch.title,
|
||||
"has_content": bool(ch.content and ch.content.strip()),
|
||||
"word_count": ch.word_count or 0
|
||||
}
|
||||
for ch in previous_chapters
|
||||
]
|
||||
|
||||
return {
|
||||
"can_generate": can_generate,
|
||||
"reason": error_msg if not can_generate else "",
|
||||
"previous_chapters": previous_info,
|
||||
"chapter_number": chapter.chapter_number
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{chapter_id}/generate", summary="AI创作章节内容")
|
||||
async def generate_chapter_content(
|
||||
chapter_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
根据大纲、前置章节内容和项目信息AI创作章节完整内容
|
||||
要求:必须按顺序生成,确保前置章节都已完成
|
||||
"""
|
||||
# 获取章节
|
||||
result = await db.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
chapter = result.scalar_one_or_none()
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
# 检查前置条件
|
||||
can_generate, error_msg, previous_chapters = await check_prerequisites(db, chapter)
|
||||
if not can_generate:
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
|
||||
try:
|
||||
# 获取项目信息
|
||||
project_result = await db.execute(
|
||||
select(Project).where(Project.id == chapter.project_id)
|
||||
)
|
||||
project = project_result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
# 获取对应的大纲(使用新的查询确保获取最新数据)
|
||||
outline_result = await db.execute(
|
||||
select(Outline)
|
||||
.where(Outline.project_id == chapter.project_id)
|
||||
.where(Outline.order_index == chapter.chapter_number)
|
||||
.execution_options(populate_existing=True)
|
||||
)
|
||||
outline = outline_result.scalar_one_or_none()
|
||||
|
||||
# 获取所有大纲用于上下文(使用新的查询确保获取最新数据)
|
||||
all_outlines_result = await db.execute(
|
||||
select(Outline)
|
||||
.where(Outline.project_id == chapter.project_id)
|
||||
.order_by(Outline.order_index)
|
||||
.execution_options(populate_existing=True)
|
||||
)
|
||||
all_outlines = all_outlines_result.scalars().all()
|
||||
outlines_context = "\n".join([
|
||||
f"第{o.order_index}章 {o.title}: {o.content[:100]}..."
|
||||
for o in all_outlines
|
||||
])
|
||||
|
||||
# 获取角色信息
|
||||
characters_result = await db.execute(
|
||||
select(Character).where(Character.project_id == chapter.project_id)
|
||||
)
|
||||
characters = characters_result.scalars().all()
|
||||
characters_info = "\n".join([
|
||||
f"- {c.name}({'组织' if c.is_organization else '角色'}, {c.role_type}): {c.personality[:100] if c.personality else ''}"
|
||||
for c in characters
|
||||
])
|
||||
|
||||
# 构建前置章节内容上下文(如果有前置章节)
|
||||
previous_content = ""
|
||||
if previous_chapters:
|
||||
# Token控制:保留最近3章的完整内容,早期章节使用摘要
|
||||
recent_chapters = previous_chapters[-3:] if len(previous_chapters) > 3 else previous_chapters
|
||||
early_chapters = previous_chapters[:-3] if len(previous_chapters) > 3 else []
|
||||
|
||||
# 早期章节摘要
|
||||
if early_chapters:
|
||||
early_summary = "【前期剧情概要】\n" + "\n".join([
|
||||
f"第{ch.chapter_number}章《{ch.title}》:{ch.content[:200] if ch.content else ''}..."
|
||||
for ch in early_chapters
|
||||
])
|
||||
previous_content += early_summary + "\n\n"
|
||||
|
||||
# 最近章节完整内容
|
||||
if recent_chapters:
|
||||
recent_content = "【最近章节完整内容】\n" + "\n\n".join([
|
||||
f"=== 第{ch.chapter_number}章:{ch.title} ===\n{ch.content}"
|
||||
for ch in recent_chapters
|
||||
])
|
||||
previous_content += recent_content
|
||||
|
||||
logger.info(f"构建前置上下文:{len(early_chapters)}章摘要 + {len(recent_chapters)}章完整内容")
|
||||
|
||||
# 根据是否有前置内容选择不同的提示词
|
||||
if previous_content:
|
||||
# 使用带上下文的提示词
|
||||
prompt = prompt_service.get_chapter_generation_with_context_prompt(
|
||||
title=project.title,
|
||||
theme=project.theme or '',
|
||||
genre=project.genre or '',
|
||||
narrative_perspective=project.narrative_perspective or '第三人称',
|
||||
time_period=project.world_time_period or '未设定',
|
||||
location=project.world_location or '未设定',
|
||||
atmosphere=project.world_atmosphere or '未设定',
|
||||
rules=project.world_rules or '未设定',
|
||||
characters_info=characters_info or '暂无角色信息',
|
||||
outlines_context=outlines_context,
|
||||
previous_content=previous_content,
|
||||
chapter_number=chapter.chapter_number,
|
||||
chapter_title=chapter.title,
|
||||
chapter_outline=outline.content if outline else chapter.summary or '暂无大纲'
|
||||
)
|
||||
else:
|
||||
# 第一章,使用原有提示词
|
||||
prompt = prompt_service.get_chapter_generation_prompt(
|
||||
title=project.title,
|
||||
theme=project.theme or '',
|
||||
genre=project.genre or '',
|
||||
narrative_perspective=project.narrative_perspective or '第三人称',
|
||||
time_period=project.world_time_period or '未设定',
|
||||
location=project.world_location or '未设定',
|
||||
atmosphere=project.world_atmosphere or '未设定',
|
||||
rules=project.world_rules or '未设定',
|
||||
characters_info=characters_info or '暂无角色信息',
|
||||
outlines_context=outlines_context,
|
||||
chapter_number=chapter.chapter_number,
|
||||
chapter_title=chapter.title,
|
||||
chapter_outline=outline.content if outline else chapter.summary or '暂无大纲'
|
||||
)
|
||||
|
||||
logger.info(f"开始AI创作章节 {chapter_id}")
|
||||
|
||||
# 调用AI生成
|
||||
ai_content = await ai_service.generate_text(
|
||||
prompt=prompt
|
||||
)
|
||||
|
||||
# 更新章节内容
|
||||
old_word_count = chapter.word_count or 0
|
||||
chapter.content = ai_content
|
||||
new_word_count = len(ai_content)
|
||||
chapter.word_count = new_word_count
|
||||
chapter.status = "completed"
|
||||
|
||||
# 更新项目字数
|
||||
project.current_words = project.current_words - old_word_count + new_word_count
|
||||
|
||||
# 记录生成历史
|
||||
history = GenerationHistory(
|
||||
project_id=chapter.project_id,
|
||||
chapter_id=chapter.id,
|
||||
prompt=f"创作章节: 第{chapter.chapter_number}章 {chapter.title}",
|
||||
generated_content=ai_content[:500] if len(ai_content) > 500 else ai_content,
|
||||
model="default"
|
||||
)
|
||||
db.add(history)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(chapter)
|
||||
|
||||
logger.info(f"成功创作章节 {chapter_id},共 {new_word_count} 字")
|
||||
|
||||
return {"content": ai_content}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创作章节失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"创作章节失败: {str(e)}")
|
||||
|
||||
@router.post("/{chapter_id}/generate-stream", summary="AI创作章节内容(流式)")
|
||||
async def generate_chapter_content_stream(
|
||||
chapter_id: str,
|
||||
request: Request
|
||||
):
|
||||
"""
|
||||
根据大纲、前置章节内容和项目信息AI创作章节完整内容(流式返回)
|
||||
要求:必须按顺序生成,确保前置章节都已完成
|
||||
|
||||
注意:此函数不使用依赖注入的db,而是在生成器内部创建独立的数据库会话
|
||||
以避免流式响应期间的连接泄漏问题
|
||||
"""
|
||||
# 预先验证章节存在性(使用临时会话)
|
||||
async for temp_db in get_db(request):
|
||||
try:
|
||||
result = await temp_db.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
chapter = result.scalar_one_or_none()
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
# 检查前置条件
|
||||
can_generate, error_msg, previous_chapters = await check_prerequisites(temp_db, chapter)
|
||||
if not can_generate:
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
|
||||
# 保存前置章节数据供生成器使用
|
||||
previous_chapters_data = [
|
||||
{
|
||||
'id': ch.id,
|
||||
'chapter_number': ch.chapter_number,
|
||||
'title': ch.title,
|
||||
'content': ch.content
|
||||
}
|
||||
for ch in previous_chapters
|
||||
]
|
||||
finally:
|
||||
await temp_db.close()
|
||||
break
|
||||
|
||||
async def event_generator():
|
||||
# 在生成器内部创建独立的数据库会话
|
||||
db_session = None
|
||||
db_committed = False
|
||||
try:
|
||||
# 创建新的数据库会话
|
||||
async for db_session in get_db(request):
|
||||
# 重新获取章节信息
|
||||
chapter_result = await db_session.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
current_chapter = chapter_result.scalar_one_or_none()
|
||||
if not current_chapter:
|
||||
yield f"data: {json.dumps({'type': 'error', 'error': '章节不存在'}, ensure_ascii=False)}\n\n"
|
||||
return
|
||||
|
||||
# 获取项目信息
|
||||
project_result = await db_session.execute(
|
||||
select(Project).where(Project.id == current_chapter.project_id)
|
||||
)
|
||||
project = project_result.scalar_one_or_none()
|
||||
if not project:
|
||||
yield f"data: {json.dumps({'type': 'error', 'error': '项目不存在'}, ensure_ascii=False)}\n\n"
|
||||
return
|
||||
|
||||
# 获取对应的大纲
|
||||
outline_result = await db_session.execute(
|
||||
select(Outline)
|
||||
.where(Outline.project_id == current_chapter.project_id)
|
||||
.where(Outline.order_index == current_chapter.chapter_number)
|
||||
.execution_options(populate_existing=True)
|
||||
)
|
||||
outline = outline_result.scalar_one_or_none()
|
||||
|
||||
# 获取所有大纲用于上下文
|
||||
all_outlines_result = await db_session.execute(
|
||||
select(Outline)
|
||||
.where(Outline.project_id == current_chapter.project_id)
|
||||
.order_by(Outline.order_index)
|
||||
.execution_options(populate_existing=True)
|
||||
)
|
||||
all_outlines = all_outlines_result.scalars().all()
|
||||
outlines_context = "\n".join([
|
||||
f"第{o.order_index}章 {o.title}: {o.content[:100]}..."
|
||||
for o in all_outlines
|
||||
])
|
||||
|
||||
# 获取角色信息
|
||||
characters_result = await db_session.execute(
|
||||
select(Character).where(Character.project_id == current_chapter.project_id)
|
||||
)
|
||||
characters = characters_result.scalars().all()
|
||||
characters_info = "\n".join([
|
||||
f"- {c.name}({'组织' if c.is_organization else '角色'}, {c.role_type}): {c.personality[:100] if c.personality else ''}"
|
||||
for c in characters
|
||||
])
|
||||
|
||||
# 构建前置章节内容上下文(使用之前保存的数据)
|
||||
previous_content = ""
|
||||
if previous_chapters_data:
|
||||
recent_chapters = previous_chapters_data[-3:] if len(previous_chapters_data) > 3 else previous_chapters_data
|
||||
early_chapters = previous_chapters_data[:-3] if len(previous_chapters_data) > 3 else []
|
||||
|
||||
if early_chapters:
|
||||
early_summary = "【前期剧情概要】\n" + "\n".join([
|
||||
f"第{ch['chapter_number']}章《{ch['title']}》:{ch['content'][:200] if ch['content'] else ''}..."
|
||||
for ch in early_chapters
|
||||
])
|
||||
previous_content += early_summary + "\n\n"
|
||||
|
||||
if recent_chapters:
|
||||
recent_content = "【最近章节完整内容】\n" + "\n\n".join([
|
||||
f"=== 第{ch['chapter_number']}章:{ch['title']} ===\n{ch['content']}"
|
||||
for ch in recent_chapters
|
||||
])
|
||||
previous_content += recent_content
|
||||
|
||||
logger.info(f"构建前置上下文:{len(early_chapters)}章摘要 + {len(recent_chapters)}章完整内容")
|
||||
|
||||
# 发送开始事件
|
||||
yield f"data: {json.dumps({'type': 'start', 'message': '开始AI创作...'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 根据是否有前置内容选择不同的提示词
|
||||
if previous_content:
|
||||
prompt = prompt_service.get_chapter_generation_with_context_prompt(
|
||||
title=project.title,
|
||||
theme=project.theme or '',
|
||||
genre=project.genre or '',
|
||||
narrative_perspective=project.narrative_perspective or '第三人称',
|
||||
time_period=project.world_time_period or '未设定',
|
||||
location=project.world_location or '未设定',
|
||||
atmosphere=project.world_atmosphere or '未设定',
|
||||
rules=project.world_rules or '未设定',
|
||||
characters_info=characters_info or '暂无角色信息',
|
||||
outlines_context=outlines_context,
|
||||
previous_content=previous_content,
|
||||
chapter_number=current_chapter.chapter_number,
|
||||
chapter_title=current_chapter.title,
|
||||
chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲'
|
||||
)
|
||||
else:
|
||||
prompt = prompt_service.get_chapter_generation_prompt(
|
||||
title=project.title,
|
||||
theme=project.theme or '',
|
||||
genre=project.genre or '',
|
||||
narrative_perspective=project.narrative_perspective or '第三人称',
|
||||
time_period=project.world_time_period or '未设定',
|
||||
location=project.world_location or '未设定',
|
||||
atmosphere=project.world_atmosphere or '未设定',
|
||||
rules=project.world_rules or '未设定',
|
||||
characters_info=characters_info or '暂无角色信息',
|
||||
outlines_context=outlines_context,
|
||||
chapter_number=current_chapter.chapter_number,
|
||||
chapter_title=current_chapter.title,
|
||||
chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲'
|
||||
)
|
||||
|
||||
logger.info(f"开始AI流式创作章节 {chapter_id}")
|
||||
|
||||
# 流式生成内容
|
||||
full_content = ""
|
||||
async for chunk in ai_service.generate_text_stream(prompt=prompt):
|
||||
full_content += chunk
|
||||
yield f"data: {json.dumps({'type': 'content', 'content': chunk}, ensure_ascii=False)}\n\n"
|
||||
await asyncio.sleep(0) # 让出控制权
|
||||
|
||||
# 更新章节内容到数据库
|
||||
old_word_count = current_chapter.word_count or 0
|
||||
current_chapter.content = full_content
|
||||
new_word_count = len(full_content)
|
||||
current_chapter.word_count = new_word_count
|
||||
current_chapter.status = "completed"
|
||||
|
||||
# 更新项目字数
|
||||
project.current_words = project.current_words - old_word_count + new_word_count
|
||||
|
||||
# 记录生成历史
|
||||
history = GenerationHistory(
|
||||
project_id=current_chapter.project_id,
|
||||
chapter_id=current_chapter.id,
|
||||
prompt=f"创作章节: 第{current_chapter.chapter_number}章 {current_chapter.title}",
|
||||
generated_content=full_content[:500] if len(full_content) > 500 else full_content,
|
||||
model="default"
|
||||
)
|
||||
db_session.add(history)
|
||||
|
||||
await db_session.commit()
|
||||
db_committed = True
|
||||
await db_session.refresh(current_chapter)
|
||||
|
||||
logger.info(f"成功创作章节 {chapter_id},共 {new_word_count} 字")
|
||||
|
||||
# 发送完成事件
|
||||
yield f"data: {json.dumps({'type': 'done', 'message': '创作完成', 'word_count': new_word_count}, ensure_ascii=False)}\n\n"
|
||||
|
||||
break # 退出async for db_session循环
|
||||
|
||||
except GeneratorExit:
|
||||
# SSE连接断开
|
||||
logger.warning("章节生成器被提前关闭(SSE断开)")
|
||||
if db_session and not db_committed:
|
||||
try:
|
||||
if db_session.in_transaction():
|
||||
await db_session.rollback()
|
||||
logger.info("章节生成事务已回滚(GeneratorExit)")
|
||||
except Exception as e:
|
||||
logger.error(f"GeneratorExit回滚失败: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"流式创作章节失败: {str(e)}")
|
||||
if db_session and not db_committed:
|
||||
try:
|
||||
if db_session.in_transaction():
|
||||
await db_session.rollback()
|
||||
logger.info("章节生成事务已回滚(异常)")
|
||||
except Exception as rollback_error:
|
||||
logger.error(f"回滚失败: {str(rollback_error)}")
|
||||
yield f"data: {json.dumps({'type': 'error', 'error': str(e)}, ensure_ascii=False)}\n\n"
|
||||
finally:
|
||||
# 确保数据库会话被正确关闭
|
||||
if db_session:
|
||||
try:
|
||||
# 最后检查:确保没有未提交的事务
|
||||
if not db_committed and db_session.in_transaction():
|
||||
await db_session.rollback()
|
||||
logger.warning("在finally中发现未提交事务,已回滚")
|
||||
|
||||
await db_session.close()
|
||||
logger.info("数据库会话已关闭")
|
||||
except Exception as close_error:
|
||||
logger.error(f"关闭数据库会话失败: {str(close_error)}")
|
||||
# 强制关闭
|
||||
try:
|
||||
await db_session.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,491 @@
|
||||
"""角色管理API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
import json
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.character import Character
|
||||
from app.models.project import Project
|
||||
from app.models.generation_history import GenerationHistory
|
||||
from app.models.relationship import CharacterRelationship, Organization, OrganizationMember, RelationshipType
|
||||
from app.schemas.character import (
|
||||
CharacterUpdate,
|
||||
CharacterResponse,
|
||||
CharacterListResponse,
|
||||
CharacterGenerateRequest
|
||||
)
|
||||
from app.services.ai_service import ai_service
|
||||
from app.services.prompt_service import prompt_service
|
||||
from app.logger import get_logger
|
||||
|
||||
router = APIRouter(prefix="/characters", tags=["角色管理"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@router.get("", response_model=CharacterListResponse, summary="获取角色列表")
|
||||
async def get_characters(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取指定项目的所有角色(query参数版本)"""
|
||||
# 获取总数
|
||||
count_result = await db.execute(
|
||||
select(func.count(Character.id)).where(Character.project_id == project_id)
|
||||
)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# 获取角色列表
|
||||
result = await db.execute(
|
||||
select(Character)
|
||||
.where(Character.project_id == project_id)
|
||||
.order_by(Character.created_at.desc())
|
||||
)
|
||||
characters = result.scalars().all()
|
||||
|
||||
return CharacterListResponse(total=total, items=characters)
|
||||
|
||||
|
||||
@router.get("/project/{project_id}", response_model=CharacterListResponse, summary="获取项目的所有角色")
|
||||
async def get_project_characters(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取指定项目的所有角色(路径参数版本)"""
|
||||
# 获取总数
|
||||
count_result = await db.execute(
|
||||
select(func.count(Character.id)).where(Character.project_id == project_id)
|
||||
)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# 获取角色列表
|
||||
result = await db.execute(
|
||||
select(Character)
|
||||
.where(Character.project_id == project_id)
|
||||
.order_by(Character.created_at.desc())
|
||||
)
|
||||
characters = result.scalars().all()
|
||||
|
||||
return CharacterListResponse(total=total, items=characters)
|
||||
|
||||
|
||||
@router.get("/{character_id}", response_model=CharacterResponse, summary="获取角色详情")
|
||||
async def get_character(
|
||||
character_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""根据ID获取角色详情"""
|
||||
result = await db.execute(
|
||||
select(Character).where(Character.id == character_id)
|
||||
)
|
||||
character = result.scalar_one_or_none()
|
||||
|
||||
if not character:
|
||||
raise HTTPException(status_code=404, detail="角色不存在")
|
||||
|
||||
return character
|
||||
|
||||
|
||||
@router.put("/{character_id}", response_model=CharacterResponse, summary="更新角色")
|
||||
async def update_character(
|
||||
character_id: str,
|
||||
character_update: CharacterUpdate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""更新角色信息"""
|
||||
result = await db.execute(
|
||||
select(Character).where(Character.id == character_id)
|
||||
)
|
||||
character = result.scalar_one_or_none()
|
||||
|
||||
if not character:
|
||||
raise HTTPException(status_code=404, detail="角色不存在")
|
||||
|
||||
# 更新字段
|
||||
update_data = character_update.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(character, field, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(character)
|
||||
return character
|
||||
|
||||
|
||||
@router.delete("/{character_id}", summary="删除角色")
|
||||
async def delete_character(
|
||||
character_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""删除角色"""
|
||||
result = await db.execute(
|
||||
select(Character).where(Character.id == character_id)
|
||||
)
|
||||
character = result.scalar_one_or_none()
|
||||
|
||||
if not character:
|
||||
raise HTTPException(status_code=404, detail="角色不存在")
|
||||
|
||||
await db.delete(character)
|
||||
await db.commit()
|
||||
|
||||
return {"message": "角色删除成功"}
|
||||
|
||||
|
||||
@router.post("/generate", response_model=CharacterResponse, summary="AI生成角色")
|
||||
async def generate_character(
|
||||
request: CharacterGenerateRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
使用AI生成角色卡
|
||||
|
||||
根据用户输入的信息,结合项目的世界观、主题等背景,
|
||||
AI会生成一个完整、详细的角色设定卡片。
|
||||
|
||||
生成内容包括:姓名、年龄、性别、性格、外貌、背景故事、人际关系等
|
||||
"""
|
||||
# 验证项目是否存在并获取项目信息
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == request.project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
try:
|
||||
# 获取已存在的角色列表,用于关系网络
|
||||
existing_chars_result = await db.execute(
|
||||
select(Character)
|
||||
.where(Character.project_id == request.project_id)
|
||||
.order_by(Character.created_at.desc())
|
||||
)
|
||||
existing_characters = existing_chars_result.scalars().all()
|
||||
|
||||
# 构建现有角色信息摘要(包含组织)
|
||||
existing_chars_info = ""
|
||||
character_list = []
|
||||
organization_list = []
|
||||
|
||||
if existing_characters:
|
||||
for c in existing_characters[:10]: # 最多显示10个
|
||||
if c.is_organization:
|
||||
organization_list.append(f"- {c.name} [{c.organization_type or '组织'}]")
|
||||
else:
|
||||
character_list.append(f"- {c.name}({c.role_type or '未知'})")
|
||||
|
||||
if character_list:
|
||||
existing_chars_info += "\n已有角色:\n" + "\n".join(character_list)
|
||||
if organization_list:
|
||||
existing_chars_info += "\n\n已有组织:\n" + "\n".join(organization_list)
|
||||
|
||||
# 构建项目上下文信息
|
||||
project_context = f"""
|
||||
项目信息:
|
||||
- 书名:{project.title}
|
||||
- 主题:{project.theme or '未设定'}
|
||||
- 类型:{project.genre or '未设定'}
|
||||
- 时间背景:{project.world_time_period or '未设定'}
|
||||
- 地理位置:{project.world_location or '未设定'}
|
||||
- 氛围基调:{project.world_atmosphere or '未设定'}
|
||||
- 世界规则:{project.world_rules or '未设定'}
|
||||
{existing_chars_info}
|
||||
"""
|
||||
|
||||
# 构建用户输入信息
|
||||
user_input = f"""
|
||||
用户要求:
|
||||
- 角色名称:{request.name or '请AI生成'}
|
||||
- 角色定位:{request.role_type or 'supporting'}(protagonist=主角, supporting=配角, antagonist=反派)
|
||||
- 背景设定:{request.background or '无特殊要求'}
|
||||
- 其他要求:{request.requirements or '无'}
|
||||
"""
|
||||
|
||||
# 使用统一的提示词服务
|
||||
prompt = prompt_service.get_single_character_prompt(
|
||||
project_context=project_context,
|
||||
user_input=user_input
|
||||
)
|
||||
|
||||
# 调用AI生成角色
|
||||
logger.info(f"🎯 开始为项目 {request.project_id} 生成角色")
|
||||
logger.info(f" - 角色名:{request.name or 'AI生成'}")
|
||||
logger.info(f" - 角色定位:{request.role_type}")
|
||||
logger.info(f" - 背景设定:{request.background or '无'}")
|
||||
logger.info(f" - AI提供商:{request.provider or 'default'}")
|
||||
logger.info(f" - AI模型:{request.model or 'default'}")
|
||||
logger.info(f" - Prompt长度:{len(prompt)} 字符")
|
||||
|
||||
try:
|
||||
ai_response = await ai_service.generate_text(
|
||||
prompt=prompt,
|
||||
provider=request.provider,
|
||||
model=request.model
|
||||
)
|
||||
logger.info(f"✅ AI响应接收完成,长度:{len(ai_response) if ai_response else 0} 字符")
|
||||
except Exception as ai_error:
|
||||
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"AI服务调用失败:{str(ai_error)}"
|
||||
)
|
||||
|
||||
# 检查AI响应
|
||||
if not ai_response or not ai_response.strip():
|
||||
logger.error("❌ AI返回了空响应")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="AI服务返回空响应。可能原因:1) API配置错误 2) 模型不支持 3) 网络问题。请检查后端日志。"
|
||||
)
|
||||
|
||||
logger.info(f"📝 开始清理AI响应")
|
||||
# 清理AI响应,移除可能的markdown标记
|
||||
cleaned_response = ai_response.strip()
|
||||
original_length = len(cleaned_response)
|
||||
|
||||
if cleaned_response.startswith("```json"):
|
||||
cleaned_response = cleaned_response[7:]
|
||||
logger.info(" - 移除了 ```json 标记")
|
||||
if cleaned_response.startswith("```"):
|
||||
cleaned_response = cleaned_response[3:]
|
||||
logger.info(" - 移除了 ``` 标记")
|
||||
if cleaned_response.endswith("```"):
|
||||
cleaned_response = cleaned_response[:-3]
|
||||
logger.info(" - 移除了末尾 ``` 标记")
|
||||
cleaned_response = cleaned_response.strip()
|
||||
|
||||
logger.info(f" - 清理前长度:{original_length},清理后长度:{len(cleaned_response)}")
|
||||
logger.info(f" - 清理后内容预览(前300字符):{cleaned_response[:300]}")
|
||||
|
||||
# 解析AI响应
|
||||
logger.info(f"🔍 开始解析JSON")
|
||||
try:
|
||||
character_data = json.loads(cleaned_response)
|
||||
logger.info(f"✅ JSON解析成功")
|
||||
logger.info(f" - 解析后的字段:{list(character_data.keys())}")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"❌ JSON解析失败")
|
||||
logger.error(f" - 错误位置:line {e.lineno}, column {e.colno}")
|
||||
logger.error(f" - 错误信息:{str(e)}")
|
||||
logger.error(f" - 完整响应内容(前1000字符):{cleaned_response[:1000]}")
|
||||
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"AI返回的内容无法解析为JSON。错误:{str(e)}。响应内容已记录到日志,请查看后端日志排查。"
|
||||
)
|
||||
|
||||
# 转换traits为JSON字符串
|
||||
traits_json = json.dumps(character_data.get("traits", []), ensure_ascii=False) if character_data.get("traits") else None
|
||||
|
||||
# 判断是否为组织
|
||||
is_organization = character_data.get("is_organization", False)
|
||||
|
||||
# 创建角色
|
||||
character = Character(
|
||||
project_id=request.project_id,
|
||||
name=character_data.get("name", request.name or "未命名角色"),
|
||||
age=str(character_data.get("age", "")),
|
||||
gender=character_data.get("gender"),
|
||||
is_organization=is_organization,
|
||||
role_type=request.role_type or "supporting",
|
||||
personality=character_data.get("personality", ""),
|
||||
background=character_data.get("background", ""),
|
||||
appearance=character_data.get("appearance", ""),
|
||||
relationships=character_data.get("relationships_text", character_data.get("relationships", "")), # 优先使用文本描述
|
||||
organization_type=character_data.get("organization_type") if is_organization else None,
|
||||
organization_purpose=character_data.get("organization_purpose") if is_organization else None,
|
||||
organization_members=json.dumps(character_data.get("organization_members", []), ensure_ascii=False) if is_organization else None,
|
||||
traits=traits_json
|
||||
)
|
||||
db.add(character)
|
||||
await db.flush() # 获取character.id
|
||||
|
||||
logger.info(f"✅ 角色创建成功:{character.name} (ID: {character.id}, 是否组织: {is_organization})")
|
||||
|
||||
# 如果是组织,自动创建Organization详情记录
|
||||
if is_organization:
|
||||
org_check = await db.execute(
|
||||
select(Organization).where(Organization.character_id == character.id)
|
||||
)
|
||||
existing_org = org_check.scalar_one_or_none()
|
||||
|
||||
if not existing_org:
|
||||
organization = Organization(
|
||||
character_id=character.id,
|
||||
project_id=request.project_id,
|
||||
member_count=0,
|
||||
power_level=character_data.get("power_level", 50),
|
||||
location=character_data.get("location"),
|
||||
motto=character_data.get("motto")
|
||||
)
|
||||
db.add(organization)
|
||||
await db.flush()
|
||||
logger.info(f"✅ 自动创建组织详情:{character.name} (Org ID: {organization.id})")
|
||||
else:
|
||||
logger.info(f"ℹ️ 组织详情已存在:{character.name}")
|
||||
|
||||
# 处理结构化关系数据(仅针对非组织角色)
|
||||
if not is_organization:
|
||||
relationships_data = character_data.get("relationships", [])
|
||||
if relationships_data and isinstance(relationships_data, list):
|
||||
logger.info(f"📊 开始处理 {len(relationships_data)} 条关系数据")
|
||||
created_rels = 0
|
||||
|
||||
for rel in relationships_data:
|
||||
try:
|
||||
target_name = rel.get("target_character_name")
|
||||
if not target_name:
|
||||
logger.debug(f" ⚠️ 关系缺少target_character_name,跳过")
|
||||
continue
|
||||
|
||||
target_result = await db.execute(
|
||||
select(Character).where(
|
||||
Character.project_id == request.project_id,
|
||||
Character.name == target_name
|
||||
)
|
||||
)
|
||||
target_char = target_result.scalar_one_or_none()
|
||||
|
||||
if target_char:
|
||||
# 检查是否已存在相同关系
|
||||
existing_rel = await db.execute(
|
||||
select(CharacterRelationship).where(
|
||||
CharacterRelationship.project_id == request.project_id,
|
||||
CharacterRelationship.character_from_id == character.id,
|
||||
CharacterRelationship.character_to_id == target_char.id
|
||||
)
|
||||
)
|
||||
if existing_rel.scalar_one_or_none():
|
||||
logger.debug(f" ℹ️ 关系已存在:{character.name} -> {target_name}")
|
||||
continue
|
||||
|
||||
relationship = CharacterRelationship(
|
||||
project_id=request.project_id,
|
||||
character_from_id=character.id,
|
||||
character_to_id=target_char.id,
|
||||
relationship_name=rel.get("relationship_type", "未知关系"),
|
||||
intimacy_level=rel.get("intimacy_level", 50),
|
||||
description=rel.get("description", ""),
|
||||
started_at=rel.get("started_at"),
|
||||
source="ai"
|
||||
)
|
||||
|
||||
# 匹配预定义关系类型
|
||||
rel_type_result = await db.execute(
|
||||
select(RelationshipType).where(
|
||||
RelationshipType.name == rel.get("relationship_type")
|
||||
)
|
||||
)
|
||||
rel_type = rel_type_result.scalar_one_or_none()
|
||||
if rel_type:
|
||||
relationship.relationship_type_id = rel_type.id
|
||||
|
||||
db.add(relationship)
|
||||
created_rels += 1
|
||||
logger.info(f" ✅ 创建关系:{character.name} -> {target_name} ({rel.get('relationship_type')})")
|
||||
else:
|
||||
logger.warning(f" ⚠️ 目标角色不存在:{target_name}")
|
||||
|
||||
except Exception as rel_error:
|
||||
logger.warning(f" ❌ 创建关系失败:{str(rel_error)}")
|
||||
continue
|
||||
|
||||
logger.info(f"✅ 成功创建 {created_rels} 条关系记录")
|
||||
|
||||
# 处理组织成员关系(仅针对非组织角色)
|
||||
if not is_organization:
|
||||
org_memberships = character_data.get("organization_memberships", [])
|
||||
if org_memberships and isinstance(org_memberships, list):
|
||||
logger.info(f"🏢 开始处理 {len(org_memberships)} 条组织成员关系")
|
||||
created_members = 0
|
||||
|
||||
for membership in org_memberships:
|
||||
try:
|
||||
org_name = membership.get("organization_name")
|
||||
if not org_name:
|
||||
logger.debug(f" ⚠️ 组织成员关系缺少organization_name,跳过")
|
||||
continue
|
||||
|
||||
org_char_result = await db.execute(
|
||||
select(Character).where(
|
||||
Character.project_id == request.project_id,
|
||||
Character.name == org_name,
|
||||
Character.is_organization == True
|
||||
)
|
||||
)
|
||||
org_char = org_char_result.scalar_one_or_none()
|
||||
|
||||
if org_char:
|
||||
# 获取或创建Organization记录
|
||||
org_result = await db.execute(
|
||||
select(Organization).where(Organization.character_id == org_char.id)
|
||||
)
|
||||
org = org_result.scalar_one_or_none()
|
||||
|
||||
if not org:
|
||||
# 如果组织Character存在但Organization不存在,自动创建
|
||||
org = Organization(
|
||||
character_id=org_char.id,
|
||||
project_id=request.project_id,
|
||||
member_count=0
|
||||
)
|
||||
db.add(org)
|
||||
await db.flush()
|
||||
logger.info(f" ℹ️ 自动创建缺失的组织详情:{org_name}")
|
||||
|
||||
# 检查是否已存在成员关系
|
||||
existing_member = await db.execute(
|
||||
select(OrganizationMember).where(
|
||||
OrganizationMember.organization_id == org.id,
|
||||
OrganizationMember.character_id == character.id
|
||||
)
|
||||
)
|
||||
if existing_member.scalar_one_or_none():
|
||||
logger.debug(f" ℹ️ 成员关系已存在:{character.name} -> {org_name}")
|
||||
continue
|
||||
|
||||
# 创建成员关系
|
||||
member = OrganizationMember(
|
||||
organization_id=org.id,
|
||||
character_id=character.id,
|
||||
position=membership.get("position", "成员"),
|
||||
rank=membership.get("rank", 0),
|
||||
loyalty=membership.get("loyalty", 50),
|
||||
joined_at=membership.get("joined_at"),
|
||||
status=membership.get("status", "active"),
|
||||
source="ai"
|
||||
)
|
||||
db.add(member)
|
||||
|
||||
# 更新组织成员计数
|
||||
org.member_count += 1
|
||||
|
||||
created_members += 1
|
||||
logger.info(f" ✅ 添加成员:{character.name} -> {org_name} ({membership.get('position')})")
|
||||
else:
|
||||
logger.warning(f" ⚠️ 组织不存在:{org_name}")
|
||||
|
||||
except Exception as org_error:
|
||||
logger.warning(f" ❌ 添加组织成员失败:{str(org_error)}")
|
||||
continue
|
||||
|
||||
logger.info(f"✅ 成功创建 {created_members} 条组织成员记录")
|
||||
|
||||
# 记录生成历史
|
||||
history = GenerationHistory(
|
||||
project_id=request.project_id,
|
||||
prompt=prompt,
|
||||
generated_content=ai_response,
|
||||
model=request.model or "default"
|
||||
)
|
||||
db.add(history)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(character)
|
||||
|
||||
logger.info(f"🎉 成功为项目 {request.project_id} 生成角色: {character.name}")
|
||||
|
||||
return character
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成角色失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"生成角色失败: {str(e)}")
|
||||
@@ -0,0 +1,341 @@
|
||||
"""组织管理API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_
|
||||
from typing import List
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.relationship import Organization, OrganizationMember
|
||||
from app.models.character import Character
|
||||
from app.schemas.relationship import (
|
||||
OrganizationCreate,
|
||||
OrganizationUpdate,
|
||||
OrganizationResponse,
|
||||
OrganizationDetailResponse,
|
||||
OrganizationMemberCreate,
|
||||
OrganizationMemberUpdate,
|
||||
OrganizationMemberResponse,
|
||||
OrganizationMemberDetailResponse
|
||||
)
|
||||
from app.logger import get_logger
|
||||
|
||||
router = APIRouter(prefix="/organizations", tags=["组织管理"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@router.get("/project/{project_id}", response_model=List[OrganizationDetailResponse], summary="获取项目的所有组织")
|
||||
async def get_project_organizations(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取项目中的所有组织及其详情
|
||||
|
||||
返回组织的基本信息和统计数据
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(Organization).where(Organization.project_id == project_id)
|
||||
)
|
||||
organizations = result.scalars().all()
|
||||
|
||||
# 获取每个组织的角色信息
|
||||
org_list = []
|
||||
for org in organizations:
|
||||
char_result = await db.execute(
|
||||
select(Character).where(Character.id == org.character_id)
|
||||
)
|
||||
char = char_result.scalar_one_or_none()
|
||||
|
||||
if char:
|
||||
org_list.append(OrganizationDetailResponse(
|
||||
id=org.id,
|
||||
character_id=org.character_id,
|
||||
name=char.name,
|
||||
type=char.organization_type,
|
||||
purpose=char.organization_purpose,
|
||||
member_count=org.member_count,
|
||||
power_level=org.power_level,
|
||||
location=org.location,
|
||||
motto=org.motto,
|
||||
color=org.color
|
||||
))
|
||||
|
||||
logger.info(f"获取项目 {project_id} 的组织列表,共 {len(org_list)} 个")
|
||||
return org_list
|
||||
|
||||
|
||||
@router.get("/{org_id}", response_model=OrganizationResponse, summary="获取组织详情")
|
||||
async def get_organization(
|
||||
org_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取组织的详细信息"""
|
||||
result = await db.execute(
|
||||
select(Organization).where(Organization.id == org_id)
|
||||
)
|
||||
org = result.scalar_one_or_none()
|
||||
|
||||
if not org:
|
||||
raise HTTPException(status_code=404, detail="组织不存在")
|
||||
|
||||
return org
|
||||
|
||||
|
||||
@router.post("/", response_model=OrganizationResponse, summary="创建组织")
|
||||
async def create_organization(
|
||||
organization: OrganizationCreate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
创建新组织
|
||||
|
||||
- 需要关联到一个已存在的角色记录(is_organization=True)
|
||||
- 可以设置父组织、势力等级等属性
|
||||
"""
|
||||
# 验证角色是否存在且是组织
|
||||
char_result = await db.execute(
|
||||
select(Character).where(Character.id == organization.character_id)
|
||||
)
|
||||
char = char_result.scalar_one_or_none()
|
||||
|
||||
if not char:
|
||||
raise HTTPException(status_code=404, detail="关联的角色不存在")
|
||||
if not char.is_organization:
|
||||
raise HTTPException(status_code=400, detail="关联的角色不是组织类型")
|
||||
|
||||
# 检查是否已存在
|
||||
existing = await db.execute(
|
||||
select(Organization).where(Organization.character_id == organization.character_id)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(status_code=400, detail="该角色已有组织详情记录")
|
||||
|
||||
# 创建组织
|
||||
db_org = Organization(**organization.model_dump())
|
||||
db.add(db_org)
|
||||
await db.commit()
|
||||
await db.refresh(db_org)
|
||||
|
||||
logger.info(f"创建组织成功:{db_org.id} - {char.name}")
|
||||
return db_org
|
||||
|
||||
|
||||
@router.put("/{org_id}", response_model=OrganizationResponse, summary="更新组织")
|
||||
async def update_organization(
|
||||
org_id: str,
|
||||
organization: OrganizationUpdate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""更新组织的属性"""
|
||||
result = await db.execute(
|
||||
select(Organization).where(Organization.id == org_id)
|
||||
)
|
||||
db_org = result.scalar_one_or_none()
|
||||
|
||||
if not db_org:
|
||||
raise HTTPException(status_code=404, detail="组织不存在")
|
||||
|
||||
# 更新字段
|
||||
update_data = organization.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(db_org, field, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(db_org)
|
||||
|
||||
logger.info(f"更新组织成功:{org_id}")
|
||||
return db_org
|
||||
|
||||
|
||||
@router.delete("/{org_id}", summary="删除组织")
|
||||
async def delete_organization(
|
||||
org_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""删除组织(会级联删除所有成员关系)"""
|
||||
result = await db.execute(
|
||||
select(Organization).where(Organization.id == org_id)
|
||||
)
|
||||
db_org = result.scalar_one_or_none()
|
||||
|
||||
if not db_org:
|
||||
raise HTTPException(status_code=404, detail="组织不存在")
|
||||
|
||||
await db.delete(db_org)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"删除组织成功:{org_id}")
|
||||
return {"message": "组织删除成功", "id": org_id}
|
||||
|
||||
|
||||
# ============ 组织成员管理 ============
|
||||
|
||||
@router.get("/{org_id}/members", response_model=List[OrganizationMemberDetailResponse], summary="获取组织成员")
|
||||
async def get_organization_members(
|
||||
org_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取组织的所有成员
|
||||
|
||||
按职位等级(rank)降序排列
|
||||
"""
|
||||
# 验证组织存在
|
||||
org_result = await db.execute(
|
||||
select(Organization).where(Organization.id == org_id)
|
||||
)
|
||||
if not org_result.scalar_one_or_none():
|
||||
raise HTTPException(status_code=404, detail="组织不存在")
|
||||
|
||||
# 获取成员列表
|
||||
result = await db.execute(
|
||||
select(OrganizationMember)
|
||||
.where(OrganizationMember.organization_id == org_id)
|
||||
.order_by(OrganizationMember.rank.desc(), OrganizationMember.created_at)
|
||||
)
|
||||
members = result.scalars().all()
|
||||
|
||||
# 获取成员角色信息
|
||||
member_list = []
|
||||
for member in members:
|
||||
char_result = await db.execute(
|
||||
select(Character).where(Character.id == member.character_id)
|
||||
)
|
||||
char = char_result.scalar_one_or_none()
|
||||
|
||||
if char:
|
||||
member_list.append(OrganizationMemberDetailResponse(
|
||||
id=member.id,
|
||||
character_id=member.character_id,
|
||||
character_name=char.name,
|
||||
position=member.position,
|
||||
rank=member.rank,
|
||||
loyalty=member.loyalty,
|
||||
contribution=member.contribution,
|
||||
status=member.status,
|
||||
joined_at=member.joined_at,
|
||||
left_at=member.left_at,
|
||||
notes=member.notes
|
||||
))
|
||||
|
||||
logger.info(f"获取组织 {org_id} 的成员列表,共 {len(member_list)} 人")
|
||||
return member_list
|
||||
|
||||
|
||||
@router.post("/{org_id}/members", response_model=OrganizationMemberResponse, summary="添加组织成员")
|
||||
async def add_organization_member(
|
||||
org_id: str,
|
||||
member: OrganizationMemberCreate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
添加角色到组织
|
||||
|
||||
- 一个角色在同一组织中只能有一个职位
|
||||
- 会自动更新组织的成员计数
|
||||
"""
|
||||
# 验证组织存在
|
||||
org_result = await db.execute(
|
||||
select(Organization).where(Organization.id == org_id)
|
||||
)
|
||||
org = org_result.scalar_one_or_none()
|
||||
if not org:
|
||||
raise HTTPException(status_code=404, detail="组织不存在")
|
||||
|
||||
# 验证角色存在
|
||||
char_result = await db.execute(
|
||||
select(Character).where(Character.id == member.character_id)
|
||||
)
|
||||
char = char_result.scalar_one_or_none()
|
||||
if not char:
|
||||
raise HTTPException(status_code=404, detail="角色不存在")
|
||||
if char.is_organization:
|
||||
raise HTTPException(status_code=400, detail="不能将组织添加为成员")
|
||||
|
||||
# 检查是否已存在
|
||||
existing = await db.execute(
|
||||
select(OrganizationMember).where(
|
||||
and_(
|
||||
OrganizationMember.organization_id == org_id,
|
||||
OrganizationMember.character_id == member.character_id
|
||||
)
|
||||
)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(status_code=400, detail="该角色已在组织中")
|
||||
|
||||
# 创建成员关系
|
||||
db_member = OrganizationMember(
|
||||
organization_id=org_id,
|
||||
**member.model_dump(),
|
||||
source="manual"
|
||||
)
|
||||
db.add(db_member)
|
||||
|
||||
# 更新组织成员计数
|
||||
org.member_count += 1
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(db_member)
|
||||
|
||||
logger.info(f"添加成员成功:{char.name} 加入组织 {org_id}")
|
||||
return db_member
|
||||
|
||||
|
||||
@router.put("/members/{member_id}", response_model=OrganizationMemberResponse, summary="更新成员信息")
|
||||
async def update_organization_member(
|
||||
member_id: str,
|
||||
member: OrganizationMemberUpdate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""更新组织成员的职位、忠诚度等信息"""
|
||||
result = await db.execute(
|
||||
select(OrganizationMember).where(OrganizationMember.id == member_id)
|
||||
)
|
||||
db_member = result.scalar_one_or_none()
|
||||
|
||||
if not db_member:
|
||||
raise HTTPException(status_code=404, detail="成员记录不存在")
|
||||
|
||||
# 更新字段
|
||||
update_data = member.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(db_member, field, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(db_member)
|
||||
|
||||
logger.info(f"更新成员信息成功:{member_id}")
|
||||
return db_member
|
||||
|
||||
|
||||
@router.delete("/members/{member_id}", summary="移除组织成员")
|
||||
async def remove_organization_member(
|
||||
member_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
从组织中移除成员
|
||||
|
||||
会自动更新组织的成员计数
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(OrganizationMember).where(OrganizationMember.id == member_id)
|
||||
)
|
||||
db_member = result.scalar_one_or_none()
|
||||
|
||||
if not db_member:
|
||||
raise HTTPException(status_code=404, detail="成员记录不存在")
|
||||
|
||||
# 更新组织成员计数
|
||||
org_result = await db.execute(
|
||||
select(Organization).where(Organization.id == db_member.organization_id)
|
||||
)
|
||||
org = org_result.scalar_one()
|
||||
org.member_count = max(0, org.member_count - 1)
|
||||
|
||||
await db.delete(db_member)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"移除成员成功:{member_id}")
|
||||
return {"message": "成员移除成功", "id": member_id}
|
||||
@@ -0,0 +1,657 @@
|
||||
"""大纲管理API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, delete
|
||||
from typing import List
|
||||
import json
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.outline import Outline
|
||||
from app.models.project import Project
|
||||
from app.models.chapter import Chapter
|
||||
from app.models.character import Character
|
||||
from app.models.generation_history import GenerationHistory
|
||||
from app.schemas.outline import (
|
||||
OutlineCreate,
|
||||
OutlineUpdate,
|
||||
OutlineResponse,
|
||||
OutlineListResponse,
|
||||
OutlineGenerateRequest,
|
||||
OutlineReorderRequest
|
||||
)
|
||||
from app.services.ai_service import ai_service
|
||||
from app.services.prompt_service import prompt_service
|
||||
from app.logger import get_logger
|
||||
|
||||
router = APIRouter(prefix="/outlines", tags=["大纲管理"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@router.post("", response_model=OutlineResponse, summary="创建大纲")
|
||||
async def create_outline(
|
||||
outline: OutlineCreate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""创建新的章节大纲,同时创建对应的章节记录"""
|
||||
# 验证项目是否存在
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == outline.project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
# 创建大纲
|
||||
db_outline = Outline(**outline.model_dump())
|
||||
db.add(db_outline)
|
||||
|
||||
# 同步创建对应的章节记录
|
||||
chapter = Chapter(
|
||||
project_id=outline.project_id,
|
||||
chapter_number=outline.order_index,
|
||||
title=outline.title,
|
||||
summary=outline.content[:500] if len(outline.content) > 500 else outline.content,
|
||||
status="draft"
|
||||
)
|
||||
db.add(chapter)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(db_outline)
|
||||
return db_outline
|
||||
|
||||
|
||||
@router.get("", response_model=OutlineListResponse, summary="获取大纲列表")
|
||||
async def get_outlines(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取指定项目的所有大纲"""
|
||||
# 获取总数
|
||||
count_result = await db.execute(
|
||||
select(func.count(Outline.id)).where(Outline.project_id == project_id)
|
||||
)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# 获取大纲列表
|
||||
result = await db.execute(
|
||||
select(Outline)
|
||||
.where(Outline.project_id == project_id)
|
||||
.order_by(Outline.order_index)
|
||||
)
|
||||
outlines = result.scalars().all()
|
||||
|
||||
return OutlineListResponse(total=total, items=outlines)
|
||||
|
||||
|
||||
@router.get("/project/{project_id}", response_model=OutlineListResponse, summary="获取项目的所有大纲")
|
||||
async def get_project_outlines(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取指定项目的所有大纲(路径参数版本)"""
|
||||
# 获取总数
|
||||
count_result = await db.execute(
|
||||
select(func.count(Outline.id)).where(Outline.project_id == project_id)
|
||||
)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# 获取大纲列表
|
||||
result = await db.execute(
|
||||
select(Outline)
|
||||
.where(Outline.project_id == project_id)
|
||||
.order_by(Outline.order_index)
|
||||
)
|
||||
outlines = result.scalars().all()
|
||||
|
||||
return OutlineListResponse(total=total, items=outlines)
|
||||
|
||||
|
||||
@router.get("/{outline_id}", response_model=OutlineResponse, summary="获取大纲详情")
|
||||
async def get_outline(
|
||||
outline_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""根据ID获取大纲详情"""
|
||||
result = await db.execute(
|
||||
select(Outline).where(Outline.id == outline_id)
|
||||
)
|
||||
outline = result.scalar_one_or_none()
|
||||
|
||||
if not outline:
|
||||
raise HTTPException(status_code=404, detail="大纲不存在")
|
||||
|
||||
return outline
|
||||
|
||||
|
||||
@router.put("/{outline_id}", response_model=OutlineResponse, summary="更新大纲")
|
||||
async def update_outline(
|
||||
outline_id: str,
|
||||
outline_update: OutlineUpdate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""更新大纲信息,同步更新对应章节和structure字段"""
|
||||
result = await db.execute(
|
||||
select(Outline).where(Outline.id == outline_id)
|
||||
)
|
||||
outline = result.scalar_one_or_none()
|
||||
|
||||
if not outline:
|
||||
raise HTTPException(status_code=404, detail="大纲不存在")
|
||||
|
||||
# 更新字段
|
||||
update_data = outline_update.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(outline, field, value)
|
||||
|
||||
# 如果修改了content或title,同步更新structure字段
|
||||
if 'content' in update_data or 'title' in update_data:
|
||||
try:
|
||||
# 尝试解析现有的structure
|
||||
if outline.structure:
|
||||
structure_data = json.loads(outline.structure)
|
||||
else:
|
||||
structure_data = {}
|
||||
|
||||
# 更新structure中的对应字段
|
||||
if 'title' in update_data:
|
||||
structure_data['title'] = outline.title
|
||||
if 'content' in update_data:
|
||||
structure_data['summary'] = outline.content
|
||||
structure_data['content'] = outline.content
|
||||
|
||||
# 保存更新后的structure
|
||||
outline.structure = json.dumps(structure_data, ensure_ascii=False)
|
||||
logger.info(f"同步更新大纲 {outline_id} 的structure字段")
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"大纲 {outline_id} 的structure字段格式错误,跳过更新")
|
||||
|
||||
# 同步更新对应的章节标题和摘要
|
||||
if 'title' in update_data or 'content' in update_data:
|
||||
chapter_result = await db.execute(
|
||||
select(Chapter).where(
|
||||
Chapter.project_id == outline.project_id,
|
||||
Chapter.chapter_number == outline.order_index
|
||||
)
|
||||
)
|
||||
chapter = chapter_result.scalar_one_or_none()
|
||||
|
||||
if chapter:
|
||||
if 'title' in update_data:
|
||||
chapter.title = outline.title
|
||||
if 'content' in update_data:
|
||||
# 更新章节摘要(取content前500字符)
|
||||
chapter.summary = outline.content[:500] if len(outline.content) > 500 else outline.content
|
||||
logger.info(f"同步更新章节 {chapter.id} 的标题和摘要")
|
||||
else:
|
||||
logger.warning(f"未找到对应的章节记录 (order_index={outline.order_index})")
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(outline)
|
||||
return outline
|
||||
|
||||
|
||||
@router.delete("/{outline_id}", summary="删除大纲")
|
||||
async def delete_outline(
|
||||
outline_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""删除大纲,同步删除章节,并重新排序后续项"""
|
||||
result = await db.execute(
|
||||
select(Outline).where(Outline.id == outline_id)
|
||||
)
|
||||
outline = result.scalar_one_or_none()
|
||||
|
||||
if not outline:
|
||||
raise HTTPException(status_code=404, detail="大纲不存在")
|
||||
|
||||
project_id = outline.project_id
|
||||
deleted_order = outline.order_index
|
||||
|
||||
# 删除对应的章节
|
||||
await db.execute(
|
||||
delete(Chapter).where(
|
||||
Chapter.project_id == project_id,
|
||||
Chapter.chapter_number == deleted_order
|
||||
)
|
||||
)
|
||||
|
||||
# 删除大纲
|
||||
await db.delete(outline)
|
||||
|
||||
# 重新排序后续的大纲和章节(序号-1)
|
||||
result = await db.execute(
|
||||
select(Outline).where(
|
||||
Outline.project_id == project_id,
|
||||
Outline.order_index > deleted_order
|
||||
)
|
||||
)
|
||||
subsequent_outlines = result.scalars().all()
|
||||
|
||||
for o in subsequent_outlines:
|
||||
old_order = o.order_index
|
||||
o.order_index -= 1
|
||||
|
||||
# 同步更新对应的章节
|
||||
chapter_result = await db.execute(
|
||||
select(Chapter).where(
|
||||
Chapter.project_id == project_id,
|
||||
Chapter.chapter_number == old_order
|
||||
)
|
||||
)
|
||||
chapter = chapter_result.scalar_one_or_none()
|
||||
if chapter:
|
||||
chapter.chapter_number = old_order - 1
|
||||
|
||||
await db.commit()
|
||||
|
||||
return {"message": "大纲删除成功"}
|
||||
|
||||
|
||||
@router.post("/reorder", summary="批量重排序大纲")
|
||||
async def reorder_outlines(
|
||||
request: OutlineReorderRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
批量调整大纲顺序,同步更新章节序号
|
||||
|
||||
策略:先收集所有变更,最后一次性提交,避免临时冲突
|
||||
"""
|
||||
try:
|
||||
# 第一步:收集所有大纲和对应的章节
|
||||
outline_chapter_map = {} # {outline_id: (outline, chapter, old_order, new_order)}
|
||||
|
||||
for item in request.orders:
|
||||
outline_id = item.id
|
||||
new_order = item.order_index
|
||||
|
||||
# 获取大纲
|
||||
result = await db.execute(
|
||||
select(Outline).where(Outline.id == outline_id)
|
||||
)
|
||||
outline = result.scalar_one_or_none()
|
||||
|
||||
if not outline:
|
||||
logger.warning(f"大纲 {outline_id} 不存在,跳过")
|
||||
continue
|
||||
|
||||
old_order = outline.order_index
|
||||
|
||||
# 获取对应的章节(通过旧的chapter_number匹配)
|
||||
chapter_result = await db.execute(
|
||||
select(Chapter).where(
|
||||
Chapter.project_id == outline.project_id,
|
||||
Chapter.chapter_number == old_order
|
||||
)
|
||||
)
|
||||
chapter = chapter_result.first()
|
||||
chapter_obj = chapter[0] if chapter else None
|
||||
|
||||
outline_chapter_map[outline_id] = (outline, chapter_obj, old_order, new_order)
|
||||
|
||||
# 第二步:批量更新所有大纲和章节
|
||||
updated_outlines = 0
|
||||
updated_chapters = 0
|
||||
|
||||
for outline_id, (outline, chapter, old_order, new_order) in outline_chapter_map.items():
|
||||
# 更新大纲
|
||||
outline.order_index = new_order
|
||||
updated_outlines += 1
|
||||
|
||||
# 更新章节
|
||||
if chapter:
|
||||
chapter.chapter_number = new_order
|
||||
chapter.title = outline.title # 同步更新标题
|
||||
updated_chapters += 1
|
||||
else:
|
||||
logger.warning(f"章节 {old_order} 不存在,跳过")
|
||||
|
||||
# 第三步:一次性提交所有更改
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"重排序成功:更新了 {updated_outlines} 个大纲,{updated_chapters} 个章节")
|
||||
|
||||
return {
|
||||
"message": "重排序成功",
|
||||
"updated_outlines": updated_outlines,
|
||||
"updated_chapters": updated_chapters
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"重排序失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"重排序失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/generate", response_model=OutlineListResponse, summary="AI生成/续写大纲")
|
||||
async def generate_outline(
|
||||
request: OutlineGenerateRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
使用AI生成或续写小说大纲 - 智能模式
|
||||
|
||||
支持三种模式:
|
||||
- auto: 自动判断(无大纲→新建,有大纲→续写)
|
||||
- new: 强制全新生成
|
||||
- continue: 强制续写模式
|
||||
"""
|
||||
# 验证项目是否存在
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == request.project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
try:
|
||||
# 获取现有大纲(强制从数据库获取最新数据,包括用户手动修改的内容)
|
||||
existing_result = await db.execute(
|
||||
select(Outline)
|
||||
.where(Outline.project_id == request.project_id)
|
||||
.order_by(Outline.order_index)
|
||||
.execution_options(populate_existing=True)
|
||||
)
|
||||
existing_outlines = existing_result.scalars().all()
|
||||
|
||||
# 判断实际执行模式
|
||||
actual_mode = request.mode
|
||||
if actual_mode == "auto":
|
||||
actual_mode = "continue" if existing_outlines else "new"
|
||||
logger.info(f"自动判断模式:{'续写' if existing_outlines else '新建'}")
|
||||
|
||||
# 模式:全新生成
|
||||
if actual_mode == "new":
|
||||
return await _generate_new_outline(
|
||||
request, project, db
|
||||
)
|
||||
|
||||
# 模式:续写
|
||||
elif actual_mode == "continue":
|
||||
if not existing_outlines:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="续写模式需要已有大纲,当前项目没有大纲"
|
||||
)
|
||||
|
||||
return await _continue_outline(
|
||||
request, project, existing_outlines, db
|
||||
)
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"不支持的模式: {request.mode}"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"生成大纲失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"生成大纲失败: {str(e)}")
|
||||
|
||||
|
||||
async def _generate_new_outline(
|
||||
request: OutlineGenerateRequest,
|
||||
project: Project,
|
||||
db: AsyncSession
|
||||
) -> OutlineListResponse:
|
||||
"""全新生成大纲"""
|
||||
logger.info(f"全新生成大纲 - 项目: {project.id}, keep_existing: {request.keep_existing}")
|
||||
|
||||
# 获取角色信息
|
||||
characters_result = await db.execute(
|
||||
select(Character).where(Character.project_id == project.id)
|
||||
)
|
||||
characters = characters_result.scalars().all()
|
||||
characters_info = "\n".join([
|
||||
f"- {char.name} ({'组织' if char.is_organization else '角色'}, {char.role_type}): "
|
||||
f"{char.personality[:100] if char.personality else '暂无描述'}"
|
||||
for char in characters
|
||||
])
|
||||
|
||||
# 使用完整提示词
|
||||
prompt = prompt_service.get_complete_outline_prompt(
|
||||
title=project.title,
|
||||
theme=request.theme or project.theme or "未设定",
|
||||
genre=request.genre or project.genre or "通用",
|
||||
chapter_count=request.chapter_count,
|
||||
narrative_perspective=request.narrative_perspective,
|
||||
target_words=request.target_words,
|
||||
time_period=project.world_time_period or "未设定",
|
||||
location=project.world_location or "未设定",
|
||||
atmosphere=project.world_atmosphere or "未设定",
|
||||
rules=project.world_rules or "未设定",
|
||||
characters_info=characters_info or "暂无角色信息",
|
||||
requirements=request.requirements or ""
|
||||
)
|
||||
|
||||
# 调用AI
|
||||
ai_response = await ai_service.generate_text(
|
||||
prompt=prompt,
|
||||
provider=request.provider,
|
||||
model=request.model
|
||||
)
|
||||
|
||||
# 解析响应
|
||||
outline_data = _parse_ai_response(ai_response)
|
||||
|
||||
# 全新生成模式:必须删除旧大纲和章节
|
||||
# 注意:这是"new"模式的核心逻辑,应该始终删除旧数据
|
||||
logger.info(f"删除项目 {project.id} 的旧大纲和章节")
|
||||
await db.execute(
|
||||
delete(Outline).where(Outline.project_id == project.id)
|
||||
)
|
||||
await db.execute(
|
||||
delete(Chapter).where(Chapter.project_id == project.id)
|
||||
)
|
||||
|
||||
# 保存新大纲
|
||||
outlines = await _save_outlines(
|
||||
project.id, outline_data, db, start_index=1
|
||||
)
|
||||
|
||||
# 记录历史
|
||||
history = GenerationHistory(
|
||||
project_id=project.id,
|
||||
prompt=prompt,
|
||||
generated_content=ai_response,
|
||||
model=request.model or "default"
|
||||
)
|
||||
db.add(history)
|
||||
|
||||
await db.commit()
|
||||
|
||||
for outline in outlines:
|
||||
await db.refresh(outline)
|
||||
|
||||
logger.info(f"全新生成完成 - {len(outlines)} 章")
|
||||
return OutlineListResponse(total=len(outlines), items=outlines)
|
||||
|
||||
|
||||
async def _continue_outline(
|
||||
request: OutlineGenerateRequest,
|
||||
project: Project,
|
||||
existing_outlines: List[Outline],
|
||||
db: AsyncSession
|
||||
) -> OutlineListResponse:
|
||||
"""续写大纲"""
|
||||
logger.info(f"续写大纲 - 项目: {project.id}, 已有: {len(existing_outlines)} 章")
|
||||
|
||||
# 分析已有大纲
|
||||
current_chapter_count = len(existing_outlines)
|
||||
last_chapter_number = existing_outlines[-1].order_index
|
||||
|
||||
# 获取最近2章的剧情
|
||||
recent_outlines = existing_outlines[-2:] if len(existing_outlines) >= 2 else existing_outlines
|
||||
recent_plot = "\n".join([
|
||||
f"第{o.order_index}章《{o.title}》: {o.content}"
|
||||
for o in recent_outlines
|
||||
])
|
||||
# logger.debug(f"最近三章内容:{recent_plot}")
|
||||
# 全部章节概览
|
||||
all_chapters_brief = "\n".join([
|
||||
f"第{o.order_index}章: {o.title}"
|
||||
for o in existing_outlines
|
||||
])
|
||||
|
||||
# 获取角色信息
|
||||
characters_result = await db.execute(
|
||||
select(Character).where(Character.project_id == project.id)
|
||||
)
|
||||
characters = characters_result.scalars().all()
|
||||
characters_info = "\n".join([
|
||||
f"- {char.name} ({'组织' if char.is_organization else '角色'}, {char.role_type}): "
|
||||
f"{char.personality[:100] if char.personality else '暂无描述'}"
|
||||
for char in characters
|
||||
])
|
||||
|
||||
# 情节阶段指导
|
||||
stage_instructions = {
|
||||
"development": "继续展开情节,深化角色关系,推进主线冲突",
|
||||
"climax": "进入故事高潮,矛盾激化,关键冲突爆发",
|
||||
"ending": "解决主要冲突,收束伏笔,给出结局"
|
||||
}
|
||||
stage_instruction = stage_instructions.get(request.plot_stage, "")
|
||||
|
||||
# 使用标准续写提示词模板
|
||||
prompt = prompt_service.get_outline_continue_prompt(
|
||||
title=project.title,
|
||||
theme=request.theme or project.theme or "未设定",
|
||||
genre=request.genre or project.genre or "通用",
|
||||
narrative_perspective=request.narrative_perspective,
|
||||
chapter_count=request.chapter_count,
|
||||
time_period=project.world_time_period or "未设定",
|
||||
location=project.world_location or "未设定",
|
||||
atmosphere=project.world_atmosphere or "未设定",
|
||||
rules=project.world_rules or "未设定",
|
||||
characters_info=characters_info or "暂无角色信息",
|
||||
current_chapter_count=current_chapter_count,
|
||||
all_chapters_brief=all_chapters_brief,
|
||||
recent_plot=recent_plot,
|
||||
plot_stage_instruction=stage_instruction,
|
||||
start_chapter=last_chapter_number + 1,
|
||||
story_direction=request.story_direction or "自然延续",
|
||||
requirements=request.requirements or ""
|
||||
)
|
||||
|
||||
# 调用AI
|
||||
ai_response = await ai_service.generate_text(
|
||||
prompt=prompt,
|
||||
provider=request.provider,
|
||||
model=request.model
|
||||
)
|
||||
|
||||
# 解析响应
|
||||
outline_data = _parse_ai_response(ai_response)
|
||||
|
||||
# 保存续写的大纲
|
||||
new_outlines = await _save_outlines(
|
||||
project.id, outline_data, db, start_index=last_chapter_number + 1
|
||||
)
|
||||
|
||||
# 记录历史
|
||||
history = GenerationHistory(
|
||||
project_id=project.id,
|
||||
prompt=prompt,
|
||||
generated_content=ai_response,
|
||||
model=request.model or "default"
|
||||
)
|
||||
db.add(history)
|
||||
|
||||
await db.commit()
|
||||
|
||||
for outline in new_outlines:
|
||||
await db.refresh(outline)
|
||||
|
||||
# 返回所有大纲(包括旧的和新的)
|
||||
all_result = await db.execute(
|
||||
select(Outline)
|
||||
.where(Outline.project_id == project.id)
|
||||
.order_by(Outline.order_index)
|
||||
)
|
||||
all_outlines = all_result.scalars().all()
|
||||
|
||||
logger.info(f"续写完成 - 新增 {len(new_outlines)} 章,总计 {len(all_outlines)} 章")
|
||||
return OutlineListResponse(total=len(all_outlines), items=all_outlines)
|
||||
|
||||
|
||||
def _parse_ai_response(ai_response: str) -> list:
|
||||
"""解析AI响应为章节数据列表"""
|
||||
try:
|
||||
# 清理响应文本
|
||||
cleaned_text = ai_response.strip()
|
||||
if cleaned_text.startswith('```json'):
|
||||
cleaned_text = cleaned_text[7:]
|
||||
if cleaned_text.startswith('```'):
|
||||
cleaned_text = cleaned_text[3:]
|
||||
if cleaned_text.endswith('```'):
|
||||
cleaned_text = cleaned_text[:-3]
|
||||
cleaned_text = cleaned_text.strip()
|
||||
|
||||
outline_data = json.loads(cleaned_text)
|
||||
|
||||
# 确保是列表格式
|
||||
if not isinstance(outline_data, list):
|
||||
# 如果是对象,尝试提取chapters字段
|
||||
if isinstance(outline_data, dict):
|
||||
outline_data = outline_data.get("chapters", [outline_data])
|
||||
else:
|
||||
outline_data = [outline_data]
|
||||
|
||||
return outline_data
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"AI响应解析失败: {e}")
|
||||
# 返回一个包含原始内容的章节
|
||||
return [{
|
||||
"title": "AI生成的大纲",
|
||||
"content": ai_response[:1000],
|
||||
"summary": ai_response[:1000]
|
||||
}]
|
||||
|
||||
|
||||
async def _save_outlines(
|
||||
project_id: str,
|
||||
outline_data: list,
|
||||
db: AsyncSession,
|
||||
start_index: int = 1
|
||||
) -> List[Outline]:
|
||||
"""保存大纲到数据库"""
|
||||
outlines = []
|
||||
|
||||
for idx, chapter_data in enumerate(outline_data):
|
||||
order_idx = chapter_data.get("chapter_number", start_index + idx)
|
||||
title = chapter_data.get("title", f"第{order_idx}章")
|
||||
|
||||
# 优先使用summary,其次content
|
||||
content = chapter_data.get("summary") or chapter_data.get("content", "")
|
||||
|
||||
# 如果有额外信息,添加到内容中
|
||||
if "key_events" in chapter_data:
|
||||
content += f"\n\n关键事件:" + "、".join(chapter_data["key_events"])
|
||||
if "characters_involved" in chapter_data:
|
||||
content += f"\n涉及角色:" + "、".join(chapter_data["characters_involved"])
|
||||
|
||||
# 创建大纲
|
||||
outline = Outline(
|
||||
project_id=project_id,
|
||||
title=title,
|
||||
content=content,
|
||||
structure=json.dumps(chapter_data, ensure_ascii=False),
|
||||
order_index=order_idx
|
||||
)
|
||||
db.add(outline)
|
||||
outlines.append(outline)
|
||||
|
||||
# 同步创建章节记录
|
||||
chapter = Chapter(
|
||||
project_id=project_id,
|
||||
chapter_number=order_idx,
|
||||
title=title,
|
||||
summary=content[:500] if len(content) > 500 else content,
|
||||
status="draft"
|
||||
)
|
||||
db.add(chapter)
|
||||
|
||||
return outlines
|
||||
@@ -0,0 +1,124 @@
|
||||
"""AI去味API - 核心特色功能"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.generation_history import GenerationHistory
|
||||
from app.schemas.polish import PolishRequest, PolishResponse
|
||||
from app.services.ai_service import ai_service
|
||||
from app.services.prompt_service import prompt_service
|
||||
from app.logger import get_logger
|
||||
|
||||
router = APIRouter(prefix="/polish", tags=["AI去味"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@router.post("", response_model=PolishResponse, summary="AI去味")
|
||||
async def polish_text(
|
||||
request: PolishRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
AI去味 - 将AI生成的文本改写得更像人类作家的手笔
|
||||
|
||||
核心功能:
|
||||
- 去除AI痕迹(工整排比、重复修辞、机械总结)
|
||||
- 增加人性化(口语化、不完美细节、真实情感)
|
||||
- 优化叙事(自然节奏、简单词汇、松弛感)
|
||||
- 让对话更生活化
|
||||
|
||||
这是本项目的核心特色功能!
|
||||
"""
|
||||
try:
|
||||
# 构建AI去味提示词
|
||||
prompt = prompt_service.get_denoising_prompt(
|
||||
original_text=request.original_text
|
||||
)
|
||||
|
||||
logger.info(f"开始AI去味处理,原文长度: {len(request.original_text)}")
|
||||
|
||||
# 调用AI进行去味处理
|
||||
polished_text = await ai_service.generate_text(
|
||||
prompt=prompt,
|
||||
provider=request.provider,
|
||||
model=request.model,
|
||||
temperature=request.temperature,
|
||||
max_tokens=len(request.original_text) * 2 # 预留足够token
|
||||
)
|
||||
|
||||
# 计算字数
|
||||
word_count_before = len(request.original_text)
|
||||
word_count_after = len(polished_text)
|
||||
|
||||
logger.info(f"AI去味完成,处理后长度: {word_count_after}")
|
||||
|
||||
# 如果提供了项目ID,记录到历史
|
||||
if request.project_id:
|
||||
history = GenerationHistory(
|
||||
project_id=request.project_id,
|
||||
generation_type="polish",
|
||||
prompt=f"原文: {request.original_text[:100]}...",
|
||||
result=polished_text,
|
||||
provider=request.provider or "default",
|
||||
model=request.model or "default"
|
||||
)
|
||||
db.add(history)
|
||||
await db.commit()
|
||||
|
||||
return PolishResponse(
|
||||
original_text=request.original_text,
|
||||
polished_text=polished_text,
|
||||
word_count_before=word_count_before,
|
||||
word_count_after=word_count_after
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"AI去味失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"AI去味失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/batch", summary="批量AI去味")
|
||||
async def polish_batch(
|
||||
texts: list[str],
|
||||
project_id: int = None,
|
||||
provider: str = None,
|
||||
model: str = None,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
批量处理多个文本的AI去味
|
||||
|
||||
适用于一次性处理多个章节或段落
|
||||
"""
|
||||
try:
|
||||
results = []
|
||||
|
||||
for idx, text in enumerate(texts):
|
||||
logger.info(f"处理第 {idx+1}/{len(texts)} 个文本")
|
||||
|
||||
prompt = prompt_service.get_denoising_prompt(original_text=text)
|
||||
|
||||
polished_text = await ai_service.generate_text(
|
||||
prompt=prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
)
|
||||
|
||||
results.append({
|
||||
"index": idx,
|
||||
"original": text,
|
||||
"polished": polished_text,
|
||||
"word_count_before": len(text),
|
||||
"word_count_after": len(polished_text)
|
||||
})
|
||||
|
||||
logger.info(f"批量AI去味完成,共处理 {len(results)} 个文本")
|
||||
|
||||
return {
|
||||
"total": len(results),
|
||||
"results": results
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量AI去味失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"批量AI去味失败: {str(e)}")
|
||||
@@ -0,0 +1,414 @@
|
||||
"""项目管理API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import Response
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, delete
|
||||
from typing import List
|
||||
from app.database import get_db
|
||||
from app.models.project import Project
|
||||
from app.models.character import Character
|
||||
from app.models.outline import Outline
|
||||
from app.models.chapter import Chapter
|
||||
from app.models.generation_history import GenerationHistory
|
||||
from app.models.relationship import CharacterRelationship, Organization, OrganizationMember
|
||||
from app.schemas.project import (
|
||||
ProjectCreate,
|
||||
ProjectUpdate,
|
||||
ProjectResponse,
|
||||
ProjectListResponse
|
||||
)
|
||||
from app.logger import get_logger
|
||||
from app.utils.data_consistency import (
|
||||
run_full_data_consistency_check,
|
||||
fix_missing_organization_records,
|
||||
fix_organization_member_counts
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
router = APIRouter(prefix="/projects", tags=["项目管理"])
|
||||
|
||||
|
||||
@router.post("", response_model=ProjectResponse, summary="创建项目")
|
||||
async def create_project(
|
||||
project: ProjectCreate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
try:
|
||||
logger.info(f"创建新项目: {project.title}")
|
||||
db_project = Project(**project.model_dump())
|
||||
db.add(db_project)
|
||||
await db.commit()
|
||||
await db.refresh(db_project)
|
||||
logger.info(f"项目创建成功: {db_project.id}")
|
||||
return db_project
|
||||
except Exception as e:
|
||||
logger.error(f"创建项目失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.get("", response_model=ProjectListResponse, summary="获取项目列表")
|
||||
async def get_projects(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取所有项目列表"""
|
||||
try:
|
||||
logger.debug(f"获取项目列表: skip={skip}, limit={limit}")
|
||||
count_result = await db.execute(select(func.count(Project.id)))
|
||||
total = count_result.scalar_one()
|
||||
|
||||
result = await db.execute(
|
||||
select(Project)
|
||||
.order_by(Project.updated_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
projects = result.scalars().all()
|
||||
logger.info(f"获取项目列表成功: 共{total}个项目")
|
||||
|
||||
return ProjectListResponse(total=total, items=projects)
|
||||
except Exception as e:
|
||||
logger.error(f"获取项目列表失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/{project_id}", response_model=ProjectResponse, summary="获取项目详情")
|
||||
async def get_project(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
try:
|
||||
logger.debug(f"获取项目详情: {project_id}")
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
logger.info(f"获取项目详情成功: {project.title}")
|
||||
return project
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取项目详情失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.put("/{project_id}", response_model=ProjectResponse, summary="更新项目")
|
||||
async def update_project(
|
||||
project_id: str,
|
||||
project_update: ProjectUpdate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
try:
|
||||
logger.info(f"更新项目: {project_id}")
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
update_data = project_update.model_dump(exclude_unset=True)
|
||||
logger.debug(f"更新字段: {list(update_data.keys())}")
|
||||
for field, value in update_data.items():
|
||||
setattr(project, field, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(project)
|
||||
logger.info(f"项目更新成功: {project.title}")
|
||||
return project
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新项目失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.delete("/{project_id}", summary="删除项目")
|
||||
async def delete_project(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
try:
|
||||
logger.info(f"删除项目: {project_id}")
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
project_title = project.title
|
||||
|
||||
relationships_result = await db.execute(
|
||||
delete(CharacterRelationship).where(CharacterRelationship.project_id == project_id)
|
||||
)
|
||||
logger.debug(f"删除角色关系数: {relationships_result.rowcount}")
|
||||
|
||||
orgs_result = await db.execute(
|
||||
select(Organization).where(Organization.project_id == project_id)
|
||||
)
|
||||
orgs = orgs_result.scalars().all()
|
||||
org_member_count = 0
|
||||
for org in orgs:
|
||||
members_result = await db.execute(
|
||||
delete(OrganizationMember).where(OrganizationMember.organization_id == org.id)
|
||||
)
|
||||
org_member_count += members_result.rowcount
|
||||
logger.debug(f"删除组织成员数: {org_member_count}")
|
||||
|
||||
organizations_result = await db.execute(
|
||||
delete(Organization).where(Organization.project_id == project_id)
|
||||
)
|
||||
logger.debug(f"删除组织数: {organizations_result.rowcount}")
|
||||
|
||||
history_result = await db.execute(
|
||||
delete(GenerationHistory).where(GenerationHistory.project_id == project_id)
|
||||
)
|
||||
logger.debug(f"删除生成历史数: {history_result.rowcount}")
|
||||
|
||||
chapters_result = await db.execute(
|
||||
delete(Chapter).where(Chapter.project_id == project_id)
|
||||
)
|
||||
logger.debug(f"删除章节数: {chapters_result.rowcount}")
|
||||
|
||||
outlines_result = await db.execute(
|
||||
delete(Outline).where(Outline.project_id == project_id)
|
||||
)
|
||||
logger.debug(f"删除大纲数: {outlines_result.rowcount}")
|
||||
|
||||
characters_result = await db.execute(
|
||||
delete(Character).where(Character.project_id == project_id)
|
||||
)
|
||||
logger.debug(f"删除角色数: {characters_result.rowcount}")
|
||||
|
||||
await db.delete(project)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"项目删除成功: {project_title}")
|
||||
return {"message": "项目及所有关联数据删除成功"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"删除项目失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/{project_id}/export", summary="导出项目章节为TXT")
|
||||
async def export_project_chapters(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
导出项目的所有章节内容为TXT文本文件
|
||||
按章节顺序组织,包含项目基本信息
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始导出项目: {project_id}")
|
||||
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
chapters_result = await db.execute(
|
||||
select(Chapter)
|
||||
.where(Chapter.project_id == project_id)
|
||||
.order_by(Chapter.chapter_number)
|
||||
)
|
||||
chapters = chapters_result.scalars().all()
|
||||
|
||||
if not chapters:
|
||||
logger.warning(f"项目没有章节: {project_id}")
|
||||
raise HTTPException(status_code=404, detail="项目没有任何章节")
|
||||
|
||||
txt_content = []
|
||||
|
||||
txt_content.append("=" * 80)
|
||||
txt_content.append(f"项目标题: {project.title}")
|
||||
txt_content.append("=" * 80)
|
||||
|
||||
if project.description:
|
||||
txt_content.append(f"\n简介: {project.description}\n")
|
||||
|
||||
if project.theme:
|
||||
txt_content.append(f"主题: {project.theme}")
|
||||
|
||||
if project.genre:
|
||||
txt_content.append(f"类型: {project.genre}")
|
||||
|
||||
txt_content.append(f"总章节数: {len(chapters)}")
|
||||
txt_content.append(f"总字数: {project.current_words}")
|
||||
txt_content.append("\n" + "=" * 80 + "\n\n")
|
||||
|
||||
for chapter in chapters:
|
||||
txt_content.append(f"第 {chapter.chapter_number} 章 {chapter.title}")
|
||||
txt_content.append("-" * 80)
|
||||
txt_content.append("") # 空行
|
||||
|
||||
if chapter.content:
|
||||
txt_content.append(chapter.content)
|
||||
else:
|
||||
txt_content.append("(本章暂无内容)")
|
||||
|
||||
txt_content.append("\n\n" + "=" * 80 + "\n\n")
|
||||
|
||||
txt_content.append(f"--- 全文完 ---")
|
||||
txt_content.append(f"\n导出时间: {func.now()}")
|
||||
|
||||
final_content = "\n".join(txt_content)
|
||||
|
||||
safe_title = "".join(c for c in project.title if c.isalnum() or c in (' ', '-', '_', ',', '。', '、'))
|
||||
filename = f"{safe_title}.txt"
|
||||
|
||||
from urllib.parse import quote
|
||||
encoded_filename = quote(filename)
|
||||
|
||||
logger.info(f"导出成功: {filename}, 共{len(chapters)}章, {len(final_content)}字符")
|
||||
|
||||
return Response(
|
||||
content=final_content.encode('utf-8'),
|
||||
media_type="text/plain; charset=utf-8",
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
|
||||
"Content-Type": "text/plain; charset=utf-8"
|
||||
}
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"导出项目失败: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"导出失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/{project_id}/check-consistency", summary="检查数据一致性")
|
||||
async def check_project_consistency(
|
||||
project_id: str,
|
||||
auto_fix: bool = True,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
检查并修复项目的数据一致性问题
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
auto_fix: 是否自动修复问题(默认True)
|
||||
|
||||
返回检查报告,包含:
|
||||
- organization_records: 检查并修复缺失的Organization记录
|
||||
- member_counts: 检查并修复组织成员计数
|
||||
- relationships: 验证关系数据完整性
|
||||
- organization_members: 验证组织成员数据完整性
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始数据一致性检查: {project_id}, auto_fix={auto_fix}")
|
||||
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
report = await run_full_data_consistency_check(project_id, db, auto_fix)
|
||||
|
||||
logger.info(f"数据一致性检查完成: {project_id}")
|
||||
return report
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"数据一致性检查失败: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"检查失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/{project_id}/fix-organizations", summary="修复组织记录")
|
||||
async def fix_project_organizations(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
修复项目中缺失的Organization记录
|
||||
|
||||
为所有is_organization=True但没有Organization记录的Character创建记录
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始修复组织记录: {project_id}")
|
||||
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
fixed_count, total_count = await fix_missing_organization_records(project_id, db)
|
||||
|
||||
logger.info(f"组织记录修复完成: {project_id}, 修复{fixed_count}/{total_count}")
|
||||
return {
|
||||
"message": "组织记录修复完成",
|
||||
"fixed": fixed_count,
|
||||
"total": total_count
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"修复组织记录失败: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"修复失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/{project_id}/fix-member-counts", summary="修复成员计数")
|
||||
async def fix_project_member_counts(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
修复项目中所有组织的成员计数
|
||||
|
||||
从实际成员记录重新计算每个组织的member_count
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始修复成员计数: {project_id}")
|
||||
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
fixed_count, total_count = await fix_organization_member_counts(project_id, db)
|
||||
|
||||
logger.info(f"成员计数修复完成: {project_id}, 修复{fixed_count}/{total_count}")
|
||||
return {
|
||||
"message": "成员计数修复完成",
|
||||
"fixed": fixed_count,
|
||||
"total": total_count
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"修复成员计数失败: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"修复失败: {str(e)}")
|
||||
@@ -0,0 +1,209 @@
|
||||
"""关系管理API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, or_, and_
|
||||
from typing import List, Optional
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.relationship import (
|
||||
RelationshipType,
|
||||
CharacterRelationship,
|
||||
Organization,
|
||||
OrganizationMember
|
||||
)
|
||||
from app.models.character import Character
|
||||
from app.schemas.relationship import (
|
||||
RelationshipTypeResponse,
|
||||
CharacterRelationshipCreate,
|
||||
CharacterRelationshipUpdate,
|
||||
CharacterRelationshipResponse,
|
||||
RelationshipGraphData,
|
||||
RelationshipGraphNode,
|
||||
RelationshipGraphLink
|
||||
)
|
||||
from app.logger import get_logger
|
||||
|
||||
router = APIRouter(prefix="/relationships", tags=["关系管理"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@router.get("/types", response_model=List[RelationshipTypeResponse], summary="获取关系类型列表")
|
||||
async def get_relationship_types(db: AsyncSession = Depends(get_db)):
|
||||
"""获取所有预定义的关系类型"""
|
||||
result = await db.execute(select(RelationshipType).order_by(RelationshipType.category, RelationshipType.id))
|
||||
types = result.scalars().all()
|
||||
return types
|
||||
|
||||
|
||||
@router.get("/project/{project_id}", response_model=List[CharacterRelationshipResponse], summary="获取项目的所有关系")
|
||||
async def get_project_relationships(
|
||||
project_id: str,
|
||||
character_id: Optional[str] = Query(None, description="筛选特定角色的关系"),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取项目中的所有角色关系
|
||||
|
||||
- 如果提供character_id,则只返回与该角色相关的关系(作为发起方或接收方)
|
||||
- 否则返回项目中的所有关系
|
||||
"""
|
||||
query = select(CharacterRelationship).where(
|
||||
CharacterRelationship.project_id == project_id
|
||||
)
|
||||
|
||||
if character_id:
|
||||
query = query.where(
|
||||
or_(
|
||||
CharacterRelationship.character_from_id == character_id,
|
||||
CharacterRelationship.character_to_id == character_id
|
||||
)
|
||||
)
|
||||
|
||||
query = query.order_by(CharacterRelationship.created_at.desc())
|
||||
result = await db.execute(query)
|
||||
relationships = result.scalars().all()
|
||||
|
||||
logger.info(f"获取项目 {project_id} 的关系列表,共 {len(relationships)} 条")
|
||||
return relationships
|
||||
|
||||
|
||||
@router.get("/graph/{project_id}", response_model=RelationshipGraphData, summary="获取关系图谱数据")
|
||||
async def get_relationship_graph(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取用于可视化的关系图谱数据
|
||||
|
||||
返回格式:
|
||||
- nodes: 角色节点列表
|
||||
- links: 关系连线列表
|
||||
"""
|
||||
# 获取所有角色(节点)
|
||||
chars_result = await db.execute(
|
||||
select(Character).where(Character.project_id == project_id)
|
||||
)
|
||||
characters = chars_result.scalars().all()
|
||||
|
||||
nodes = [
|
||||
RelationshipGraphNode(
|
||||
id=c.id,
|
||||
name=c.name,
|
||||
type="organization" if c.is_organization else "character",
|
||||
role_type=c.role_type,
|
||||
avatar=c.avatar_url
|
||||
)
|
||||
for c in characters
|
||||
]
|
||||
|
||||
# 获取所有关系(边)
|
||||
rels_result = await db.execute(
|
||||
select(CharacterRelationship).where(
|
||||
CharacterRelationship.project_id == project_id
|
||||
)
|
||||
)
|
||||
relationships = rels_result.scalars().all()
|
||||
|
||||
links = [
|
||||
RelationshipGraphLink(
|
||||
source=r.character_from_id,
|
||||
target=r.character_to_id,
|
||||
relationship=r.relationship_name or "未知关系",
|
||||
intimacy=r.intimacy_level,
|
||||
status=r.status
|
||||
)
|
||||
for r in relationships
|
||||
]
|
||||
|
||||
logger.info(f"获取项目 {project_id} 的关系图谱:{len(nodes)} 个节点,{len(links)} 条关系")
|
||||
return RelationshipGraphData(nodes=nodes, links=links)
|
||||
|
||||
|
||||
@router.post("/", response_model=CharacterRelationshipResponse, summary="创建角色关系")
|
||||
async def create_relationship(
|
||||
relationship: CharacterRelationshipCreate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
手动创建角色关系
|
||||
|
||||
- 需要提供角色A和角色B的ID
|
||||
- 可以指定预定义的关系类型或自定义关系名称
|
||||
- 可以设置亲密度、状态等属性
|
||||
"""
|
||||
# 验证角色是否存在
|
||||
char_from = await db.execute(
|
||||
select(Character).where(Character.id == relationship.character_from_id)
|
||||
)
|
||||
char_to = await db.execute(
|
||||
select(Character).where(Character.id == relationship.character_to_id)
|
||||
)
|
||||
|
||||
if not char_from.scalar_one_or_none():
|
||||
raise HTTPException(status_code=404, detail=f"角色A(ID: {relationship.character_from_id})不存在")
|
||||
if not char_to.scalar_one_or_none():
|
||||
raise HTTPException(status_code=404, detail=f"角色B(ID: {relationship.character_to_id})不存在")
|
||||
|
||||
# 创建关系
|
||||
db_relationship = CharacterRelationship(
|
||||
**relationship.model_dump(),
|
||||
source="manual"
|
||||
)
|
||||
db.add(db_relationship)
|
||||
await db.commit()
|
||||
await db.refresh(db_relationship)
|
||||
|
||||
logger.info(f"创建关系成功:{relationship.character_from_id} -> {relationship.character_to_id}")
|
||||
return db_relationship
|
||||
|
||||
|
||||
@router.put("/{relationship_id}", response_model=CharacterRelationshipResponse, summary="更新关系")
|
||||
async def update_relationship(
|
||||
relationship_id: str,
|
||||
relationship: CharacterRelationshipUpdate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""更新角色关系的属性(亲密度、状态等)"""
|
||||
result = await db.execute(
|
||||
select(CharacterRelationship).where(
|
||||
CharacterRelationship.id == relationship_id
|
||||
)
|
||||
)
|
||||
db_rel = result.scalar_one_or_none()
|
||||
|
||||
if not db_rel:
|
||||
raise HTTPException(status_code=404, detail="关系不存在")
|
||||
|
||||
# 更新字段
|
||||
update_data = relationship.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(db_rel, field, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(db_rel)
|
||||
|
||||
logger.info(f"更新关系成功:{relationship_id}")
|
||||
return db_rel
|
||||
|
||||
|
||||
@router.delete("/{relationship_id}", summary="删除关系")
|
||||
async def delete_relationship(
|
||||
relationship_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""删除角色关系"""
|
||||
result = await db.execute(
|
||||
select(CharacterRelationship).where(
|
||||
CharacterRelationship.id == relationship_id
|
||||
)
|
||||
)
|
||||
db_rel = result.scalar_one_or_none()
|
||||
|
||||
if not db_rel:
|
||||
raise HTTPException(status_code=404, detail="关系不存在")
|
||||
|
||||
await db.delete(db_rel)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"删除关系成功:{relationship_id}")
|
||||
return {"message": "关系删除成功", "id": relationship_id}
|
||||
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
用户管理 API
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, Request, Depends
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
from app.user_manager import user_manager, User
|
||||
|
||||
router = APIRouter(prefix="/users", tags=["用户管理"])
|
||||
|
||||
|
||||
def require_login(request: Request):
|
||||
"""依赖:要求用户已登录"""
|
||||
if not hasattr(request.state, "user") or not request.state.user:
|
||||
raise HTTPException(status_code=401, detail="需要登录")
|
||||
return request.state.user
|
||||
|
||||
|
||||
def require_admin(request: Request):
|
||||
"""依赖:要求用户为管理员"""
|
||||
user = require_login(request)
|
||||
if not request.state.is_admin:
|
||||
raise HTTPException(status_code=403, detail="需要管理员权限")
|
||||
return user
|
||||
|
||||
|
||||
class SetAdminRequest(BaseModel):
|
||||
user_id: str
|
||||
is_admin: bool
|
||||
|
||||
|
||||
@router.get("/current")
|
||||
async def get_current_user(user: User = Depends(require_login)):
|
||||
"""获取当前登录用户信息"""
|
||||
return user.dict()
|
||||
|
||||
|
||||
@router.get("", response_model=List[dict])
|
||||
async def list_users(admin_user: User = Depends(require_admin)):
|
||||
"""
|
||||
获取所有用户列表(仅管理员)
|
||||
"""
|
||||
users = await user_manager.get_all_users()
|
||||
return [user.dict() for user in users]
|
||||
|
||||
|
||||
@router.post("/set-admin")
|
||||
async def set_admin(
|
||||
data: SetAdminRequest,
|
||||
request: Request,
|
||||
admin_user: User = Depends(require_admin)
|
||||
):
|
||||
"""
|
||||
设置用户的管理员权限(仅管理员)
|
||||
|
||||
限制:
|
||||
- 不能撤销自己的管理员权限
|
||||
- 至少保留一个管理员
|
||||
"""
|
||||
# 检查是否尝试撤销自己的权限
|
||||
if data.user_id == admin_user.user_id and not data.is_admin:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="不能撤销自己的管理员权限"
|
||||
)
|
||||
|
||||
# 尝试设置管理员权限
|
||||
success = await user_manager.set_admin(data.user_id, data.is_admin)
|
||||
|
||||
if not success:
|
||||
if not data.is_admin:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="无法撤销管理员权限,至少需要保留一个管理员"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="用户不存在"
|
||||
)
|
||||
|
||||
return {
|
||||
"message": f"已{'授予' if data.is_admin else '撤销'}管理员权限",
|
||||
"user_id": data.user_id,
|
||||
"is_admin": data.is_admin
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/{user_id}")
|
||||
async def delete_user(
|
||||
user_id: str,
|
||||
admin_user: User = Depends(require_admin)
|
||||
):
|
||||
"""
|
||||
删除用户(仅管理员)
|
||||
|
||||
限制:
|
||||
- 不能删除管理员用户
|
||||
"""
|
||||
success = await user_manager.delete_user(user_id)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="无法删除该用户(用户不存在或为管理员)"
|
||||
)
|
||||
|
||||
return {
|
||||
"message": "用户已删除",
|
||||
"user_id": user_id
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{user_id}")
|
||||
async def get_user(
|
||||
user_id: str,
|
||||
admin_user: User = Depends(require_admin)
|
||||
):
|
||||
"""获取指定用户信息(仅管理员)"""
|
||||
user = await user_manager.get_user(user_id)
|
||||
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
return user.dict()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,90 @@
|
||||
"""应用配置管理"""
|
||||
from pydantic_settings import BaseSettings
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
# 获取项目根目录(从backend/app/config.py向上两级)
|
||||
PROJECT_ROOT = Path(__file__).parent.parent
|
||||
DATA_DIR = PROJECT_ROOT / "data"
|
||||
DATA_DIR.mkdir(exist_ok=True)
|
||||
|
||||
# 配置模块使用标准logging(在logger.py初始化之前)
|
||||
config_logger = logging.getLogger(__name__)
|
||||
|
||||
# 数据库文件路径(绝对路径)
|
||||
DB_FILE = DATA_DIR / "ai_story.db"
|
||||
|
||||
# 生成数据库URL(在类外部生成,确保使用绝对路径)
|
||||
# 将Windows反斜杠转换为正斜杠,SQLite URL格式要求
|
||||
DATABASE_URL = f"sqlite+aiosqlite:///{str(DB_FILE.absolute()).replace(chr(92), '/')}"
|
||||
config_logger.debug(f"数据库文件路径: {DB_FILE}")
|
||||
config_logger.debug(f"数据库URL: {DATABASE_URL}")
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""应用配置"""
|
||||
|
||||
# 应用配置
|
||||
app_name: str = "MuMuAINovel"
|
||||
app_version: str = "1.0.0"
|
||||
app_host: str = "0.0.0.0"
|
||||
app_port: int = 8000
|
||||
debug: bool = True
|
||||
|
||||
# 日志配置
|
||||
log_level: str = "INFO" # DEBUG, INFO, WARNING, ERROR, CRITICAL
|
||||
log_to_file: bool = True # 是否输出到文件
|
||||
log_file_path: str = str(PROJECT_ROOT / "logs" / "app.log")
|
||||
log_max_bytes: int = 10 * 1024 * 1024 # 10MB
|
||||
log_backup_count: int = 30 # 保留30个备份文件
|
||||
|
||||
# CORS配置
|
||||
cors_origins: list[str] = ["http://localhost:8000", "http://127.0.0.1:8000"]
|
||||
|
||||
# 数据库配置 - 使用预先计算好的绝对路径URL
|
||||
database_url: str = DATABASE_URL
|
||||
|
||||
# AI服务配置
|
||||
openai_api_key: Optional[str] = None
|
||||
openai_base_url: Optional[str] = None
|
||||
gemini_api_key: Optional[str] = None
|
||||
gemini_base_url: Optional[str] = None
|
||||
anthropic_api_key: Optional[str] = None
|
||||
anthropic_base_url: Optional[str] = None
|
||||
default_ai_provider: str = "openai"
|
||||
default_model: str = "gpt-4"
|
||||
default_temperature: float = 0.7
|
||||
default_max_tokens: int = 2000
|
||||
|
||||
# LinuxDO OAuth2 配置
|
||||
LINUXDO_CLIENT_ID: Optional[str] = None
|
||||
LINUXDO_CLIENT_SECRET: Optional[str] = None
|
||||
# 回调地址:Docker部署时必须使用实际域名或服务器IP,不能使用localhost
|
||||
# 本地开发: http://localhost:8000/api/auth/callback
|
||||
# 生产环境: https://your-domain.com/api/auth/callback 或 http://your-ip:8000/api/auth/callback
|
||||
LINUXDO_REDIRECT_URI: Optional[str] = None
|
||||
|
||||
# 前端URL配置(用于OAuth回调后重定向)
|
||||
# 本地开发: http://localhost:8000
|
||||
# 生产环境: https://your-domain.com 或 http://your-ip:8000
|
||||
FRONTEND_URL: str = "http://localhost:8000"
|
||||
|
||||
# 初始管理员配置(LinuxDO user_id)
|
||||
INITIAL_ADMIN_LINUXDO_ID: Optional[str] = None
|
||||
|
||||
# 本地账户登录配置
|
||||
LOCAL_AUTH_ENABLED: bool = True # 是否启用本地账户登录
|
||||
LOCAL_AUTH_USERNAME: Optional[str] = None # 本地登录用户名
|
||||
LOCAL_AUTH_PASSWORD: Optional[str] = None # 本地登录密码
|
||||
LOCAL_AUTH_DISPLAY_NAME: str = "本地用户" # 本地用户显示名称
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
case_sensitive = False
|
||||
|
||||
|
||||
# 创建全局配置实例
|
||||
settings = Settings()
|
||||
config_logger.info(f"配置加载完成: {settings.app_name} v{settings.app_version}")
|
||||
config_logger.debug(f"调试模式: {settings.debug}")
|
||||
config_logger.debug(f"AI提供商: {settings.default_ai_provider}")
|
||||
@@ -0,0 +1,261 @@
|
||||
"""数据库连接和会话管理 - 支持多用户数据隔离"""
|
||||
import asyncio
|
||||
from typing import Dict, Any
|
||||
from datetime import datetime
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.orm import declarative_base
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from fastapi import Request, HTTPException
|
||||
from app.config import settings
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 创建基类
|
||||
Base = declarative_base()
|
||||
|
||||
# 引擎缓存:每个用户一个引擎
|
||||
_engine_cache: Dict[str, Any] = {}
|
||||
|
||||
# 锁管理:用于保护引擎创建过程
|
||||
_engine_locks: Dict[str, asyncio.Lock] = {}
|
||||
_cache_lock = asyncio.Lock()
|
||||
|
||||
# 会话统计(用于监控连接泄漏)
|
||||
_session_stats = {
|
||||
"created": 0,
|
||||
"closed": 0,
|
||||
"active": 0,
|
||||
"errors": 0,
|
||||
"generator_exits": 0,
|
||||
"last_check": None
|
||||
}
|
||||
|
||||
|
||||
async def get_engine(user_id: str):
|
||||
"""获取或创建用户专属的数据库引擎(线程安全)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
用户专属的异步引擎
|
||||
"""
|
||||
if user_id in _engine_cache:
|
||||
return _engine_cache[user_id]
|
||||
|
||||
async with _cache_lock:
|
||||
if user_id not in _engine_locks:
|
||||
_engine_locks[user_id] = asyncio.Lock()
|
||||
user_lock = _engine_locks[user_id]
|
||||
|
||||
async with user_lock:
|
||||
if user_id not in _engine_cache:
|
||||
db_url = f"sqlite+aiosqlite:///data/ai_story_user_{user_id}.db"
|
||||
engine = create_async_engine(
|
||||
db_url,
|
||||
echo=False,
|
||||
future=True,
|
||||
poolclass=StaticPool,
|
||||
pool_pre_ping=True,
|
||||
pool_recycle=3600,
|
||||
connect_args={
|
||||
"timeout": 30,
|
||||
"check_same_thread": False
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
async with engine.begin() as conn:
|
||||
await conn.execute(text("PRAGMA journal_mode=WAL"))
|
||||
await conn.execute(text("PRAGMA synchronous=NORMAL"))
|
||||
await conn.execute(text("PRAGMA cache_size=-64000"))
|
||||
await conn.execute(text("PRAGMA temp_store=MEMORY"))
|
||||
await conn.execute(text("PRAGMA busy_timeout=5000"))
|
||||
|
||||
logger.info(f"✅ 用户 {user_id} 的数据库已优化(WAL模式 + 64MB缓存)")
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 用户 {user_id} 数据库优化失败: {str(e)}")
|
||||
_engine_cache[user_id] = engine
|
||||
logger.info(f"为用户 {user_id} 创建数据库引擎")
|
||||
|
||||
return _engine_cache[user_id]
|
||||
|
||||
|
||||
async def get_db(request: Request):
|
||||
"""获取数据库会话的依赖函数
|
||||
|
||||
从 request.state.user_id 获取用户ID,然后返回该用户的数据库会话
|
||||
"""
|
||||
user_id = getattr(request.state, "user_id", None)
|
||||
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录或用户ID缺失")
|
||||
|
||||
engine = await get_engine(user_id)
|
||||
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False
|
||||
)
|
||||
|
||||
session = AsyncSessionLocal()
|
||||
session_id = id(session)
|
||||
|
||||
global _session_stats
|
||||
_session_stats["created"] += 1
|
||||
_session_stats["active"] += 1
|
||||
|
||||
logger.debug(f"📊 会话创建 [User:{user_id}][ID:{session_id}] - 活跃:{_session_stats['active']}, 总创建:{_session_stats['created']}, 总关闭:{_session_stats['closed']}")
|
||||
|
||||
try:
|
||||
yield session
|
||||
if session.in_transaction():
|
||||
await session.rollback()
|
||||
except GeneratorExit:
|
||||
_session_stats["generator_exits"] += 1
|
||||
logger.warning(f"⚠️ GeneratorExit [User:{user_id}][ID:{session_id}] - SSE连接断开(总计:{_session_stats['generator_exits']}次)")
|
||||
try:
|
||||
if session.in_transaction():
|
||||
await session.rollback()
|
||||
logger.info(f"✅ 事务已回滚 [User:{user_id}][ID:{session_id}](GeneratorExit)")
|
||||
except Exception as rollback_error:
|
||||
_session_stats["errors"] += 1
|
||||
logger.error(f"❌ GeneratorExit回滚失败 [User:{user_id}][ID:{session_id}]: {str(rollback_error)}")
|
||||
except Exception as e:
|
||||
_session_stats["errors"] += 1
|
||||
logger.error(f"❌ 会话异常 [User:{user_id}][ID:{session_id}]: {str(e)}")
|
||||
try:
|
||||
if session.in_transaction():
|
||||
await session.rollback()
|
||||
logger.info(f"✅ 事务已回滚 [User:{user_id}][ID:{session_id}](异常)")
|
||||
except Exception as rollback_error:
|
||||
logger.error(f"❌ 异常回滚失败 [User:{user_id}][ID:{session_id}]: {str(rollback_error)}")
|
||||
raise
|
||||
finally:
|
||||
try:
|
||||
if session.in_transaction():
|
||||
await session.rollback()
|
||||
logger.warning(f"⚠️ finally中发现未提交事务 [User:{user_id}][ID:{session_id}],已回滚")
|
||||
|
||||
await session.close()
|
||||
|
||||
_session_stats["closed"] += 1
|
||||
_session_stats["active"] -= 1
|
||||
_session_stats["last_check"] = datetime.now().isoformat()
|
||||
|
||||
logger.debug(f"📊 会话关闭 [User:{user_id}][ID:{session_id}] - 活跃:{_session_stats['active']}, 总创建:{_session_stats['created']}, 总关闭:{_session_stats['closed']}, 错误:{_session_stats['errors']}")
|
||||
|
||||
if _session_stats["active"] > 10:
|
||||
logger.warning(f"🚨 活跃会话数过多: {_session_stats['active']},可能存在连接泄漏!")
|
||||
elif _session_stats["active"] < 0:
|
||||
logger.error(f"🚨 活跃会话数异常: {_session_stats['active']},统计可能不准确!")
|
||||
|
||||
except Exception as e:
|
||||
_session_stats["errors"] += 1
|
||||
logger.error(f"❌ 关闭会话时出错 [User:{user_id}][ID:{session_id}]: {str(e)}", exc_info=True)
|
||||
try:
|
||||
await session.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
async def _init_relationship_types(user_id: str):
|
||||
"""为指定用户初始化预置的关系类型数据
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
"""
|
||||
from app.models.relationship import RelationshipType
|
||||
|
||||
relationship_types = [
|
||||
{"name": "父亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👨"},
|
||||
{"name": "母亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👩"},
|
||||
{"name": "兄弟", "category": "family", "reverse_name": "兄弟", "intimacy_range": "high", "icon": "👬"},
|
||||
{"name": "姐妹", "category": "family", "reverse_name": "姐妹", "intimacy_range": "high", "icon": "👭"},
|
||||
{"name": "子女", "category": "family", "reverse_name": "父母", "intimacy_range": "high", "icon": "👶"},
|
||||
{"name": "配偶", "category": "family", "reverse_name": "配偶", "intimacy_range": "high", "icon": "💑"},
|
||||
{"name": "恋人", "category": "family", "reverse_name": "恋人", "intimacy_range": "high", "icon": "💕"},
|
||||
|
||||
{"name": "师父", "category": "social", "reverse_name": "徒弟", "intimacy_range": "high", "icon": "🎓"},
|
||||
{"name": "徒弟", "category": "social", "reverse_name": "师父", "intimacy_range": "high", "icon": "📚"},
|
||||
{"name": "朋友", "category": "social", "reverse_name": "朋友", "intimacy_range": "medium", "icon": "🤝"},
|
||||
{"name": "同学", "category": "social", "reverse_name": "同学", "intimacy_range": "medium", "icon": "🎒"},
|
||||
{"name": "邻居", "category": "social", "reverse_name": "邻居", "intimacy_range": "low", "icon": "🏘️"},
|
||||
{"name": "知己", "category": "social", "reverse_name": "知己", "intimacy_range": "high", "icon": "💙"},
|
||||
|
||||
{"name": "上司", "category": "professional", "reverse_name": "下属", "intimacy_range": "low", "icon": "👔"},
|
||||
{"name": "下属", "category": "professional", "reverse_name": "上司", "intimacy_range": "low", "icon": "💼"},
|
||||
{"name": "同事", "category": "professional", "reverse_name": "同事", "intimacy_range": "medium", "icon": "🤵"},
|
||||
{"name": "合作伙伴", "category": "professional", "reverse_name": "合作伙伴", "intimacy_range": "medium", "icon": "🤜🤛"},
|
||||
|
||||
{"name": "敌人", "category": "hostile", "reverse_name": "敌人", "intimacy_range": "low", "icon": "⚔️"},
|
||||
{"name": "仇人", "category": "hostile", "reverse_name": "仇人", "intimacy_range": "low", "icon": "💢"},
|
||||
{"name": "竞争对手", "category": "hostile", "reverse_name": "竞争对手", "intimacy_range": "low", "icon": "🎯"},
|
||||
{"name": "宿敌", "category": "hostile", "reverse_name": "宿敌", "intimacy_range": "low", "icon": "⚡"},
|
||||
]
|
||||
|
||||
try:
|
||||
engine = await get_engine(user_id)
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False
|
||||
)
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
result = await session.execute(select(RelationshipType))
|
||||
existing = result.scalars().first()
|
||||
|
||||
if existing:
|
||||
logger.info(f"用户 {user_id} 的关系类型数据已存在,跳过初始化")
|
||||
return
|
||||
|
||||
logger.info(f"开始为用户 {user_id} 插入关系类型数据...")
|
||||
for rt_data in relationship_types:
|
||||
relationship_type = RelationshipType(**rt_data)
|
||||
session.add(relationship_type)
|
||||
|
||||
await session.commit()
|
||||
logger.info(f"成功为用户 {user_id} 插入 {len(relationship_types)} 条关系类型数据")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"用户 {user_id} 初始化关系类型数据失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
|
||||
async def init_db(user_id: str):
|
||||
"""初始化指定用户的数据库,创建所有表并插入预置数据
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始初始化用户 {user_id} 的数据库...")
|
||||
engine = await get_engine(user_id)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
await _init_relationship_types(user_id)
|
||||
|
||||
logger.info(f"用户 {user_id} 的数据库初始化成功")
|
||||
except Exception as e:
|
||||
logger.error(f"用户 {user_id} 的数据库初始化失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
async def close_db():
|
||||
"""关闭所有数据库连接"""
|
||||
try:
|
||||
logger.info("正在关闭所有数据库连接...")
|
||||
for user_id, engine in _engine_cache.items():
|
||||
await engine.dispose()
|
||||
logger.info(f"用户 {user_id} 的数据库连接已关闭")
|
||||
_engine_cache.clear()
|
||||
logger.info("所有数据库连接已关闭")
|
||||
except Exception as e:
|
||||
logger.error(f"关闭数据库连接失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
@@ -0,0 +1,73 @@
|
||||
"""初始化关系类型数据"""
|
||||
import asyncio
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.models.relationship import RelationshipType
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def init_relationship_types():
|
||||
"""初始化预置的关系类型数据"""
|
||||
|
||||
# 预置关系类型数据
|
||||
relationship_types = [
|
||||
# 家族关系
|
||||
{"name": "父亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👨"},
|
||||
{"name": "母亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👩"},
|
||||
{"name": "兄弟", "category": "family", "reverse_name": "兄弟", "intimacy_range": "high", "icon": "👬"},
|
||||
{"name": "姐妹", "category": "family", "reverse_name": "姐妹", "intimacy_range": "high", "icon": "👭"},
|
||||
{"name": "子女", "category": "family", "reverse_name": "父母", "intimacy_range": "high", "icon": "👶"},
|
||||
{"name": "配偶", "category": "family", "reverse_name": "配偶", "intimacy_range": "high", "icon": "💑"},
|
||||
{"name": "恋人", "category": "family", "reverse_name": "恋人", "intimacy_range": "high", "icon": "💕"},
|
||||
|
||||
# 社交关系
|
||||
{"name": "师父", "category": "social", "reverse_name": "徒弟", "intimacy_range": "high", "icon": "🎓"},
|
||||
{"name": "徒弟", "category": "social", "reverse_name": "师父", "intimacy_range": "high", "icon": "📚"},
|
||||
{"name": "朋友", "category": "social", "reverse_name": "朋友", "intimacy_range": "medium", "icon": "🤝"},
|
||||
{"name": "同学", "category": "social", "reverse_name": "同学", "intimacy_range": "medium", "icon": "🎒"},
|
||||
{"name": "邻居", "category": "social", "reverse_name": "邻居", "intimacy_range": "low", "icon": "🏘️"},
|
||||
{"name": "知己", "category": "social", "reverse_name": "知己", "intimacy_range": "high", "icon": "💙"},
|
||||
|
||||
# 职业关系
|
||||
{"name": "上司", "category": "professional", "reverse_name": "下属", "intimacy_range": "low", "icon": "👔"},
|
||||
{"name": "下属", "category": "professional", "reverse_name": "上司", "intimacy_range": "low", "icon": "💼"},
|
||||
{"name": "同事", "category": "professional", "reverse_name": "同事", "intimacy_range": "medium", "icon": "🤵"},
|
||||
{"name": "合作伙伴", "category": "professional", "reverse_name": "合作伙伴", "intimacy_range": "medium", "icon": "🤜🤛"},
|
||||
|
||||
# 敌对关系
|
||||
{"name": "敌人", "category": "hostile", "reverse_name": "敌人", "intimacy_range": "low", "icon": "⚔️"},
|
||||
{"name": "仇人", "category": "hostile", "reverse_name": "仇人", "intimacy_range": "low", "icon": "💢"},
|
||||
{"name": "竞争对手", "category": "hostile", "reverse_name": "竞争对手", "intimacy_range": "low", "icon": "🎯"},
|
||||
{"name": "宿敌", "category": "hostile", "reverse_name": "宿敌", "intimacy_range": "low", "icon": "⚡"},
|
||||
]
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
# 检查是否已经有数据
|
||||
result = await session.execute(select(RelationshipType))
|
||||
existing = result.scalars().first()
|
||||
|
||||
if existing:
|
||||
logger.info("关系类型数据已存在,跳过初始化")
|
||||
return
|
||||
|
||||
# 插入预置数据
|
||||
logger.info("开始插入关系类型数据...")
|
||||
for rt_data in relationship_types:
|
||||
relationship_type = RelationshipType(**rt_data)
|
||||
session.add(relationship_type)
|
||||
|
||||
await session.commit()
|
||||
logger.info(f"成功插入 {len(relationship_types)} 条关系类型数据")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"初始化关系类型数据失败: {str(e)}", exc_info=True)
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(init_relationship_types())
|
||||
@@ -0,0 +1,158 @@
|
||||
"""统一日志配置模块 - Uvicorn风格"""
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class UvicornFormatter(logging.Formatter):
|
||||
"""Uvicorn风格的日志格式化器"""
|
||||
|
||||
# 日志级别颜色(ANSI转义码)
|
||||
COLORS = {
|
||||
'DEBUG': '\033[36m', # 青色
|
||||
'INFO': '\033[32m', # 绿色
|
||||
'WARNING': '\033[33m', # 黄色
|
||||
'ERROR': '\033[31m', # 红色
|
||||
'CRITICAL': '\033[35m', # 紫色
|
||||
}
|
||||
RESET = '\033[0m'
|
||||
|
||||
def __init__(self, use_colors: bool = True):
|
||||
"""
|
||||
初始化格式化器
|
||||
|
||||
Args:
|
||||
use_colors: 是否使用颜色(控制台输出使用,文件输出不使用)
|
||||
"""
|
||||
super().__init__()
|
||||
self.use_colors = use_colors
|
||||
|
||||
def format(self, record):
|
||||
"""格式化日志记录为 Uvicorn 风格"""
|
||||
# 获取日志级别名称
|
||||
levelname = record.levelname
|
||||
|
||||
# 添加颜色(如果启用且终端支持)
|
||||
if self.use_colors and sys.stderr.isatty():
|
||||
colored_level = f"{self.COLORS.get(levelname, '')}{levelname}{self.RESET}"
|
||||
else:
|
||||
colored_level = levelname
|
||||
|
||||
# 添加请求追踪ID(如果存在)
|
||||
request_id = getattr(record, 'request_id', None)
|
||||
request_id_str = f" [{request_id}]" if request_id else ""
|
||||
|
||||
# Uvicorn风格格式: INFO: module_name - message [request_id]
|
||||
# 注意:INFO后面有5个空格,保持对齐
|
||||
return f"{colored_level}: {record.name}{request_id_str} - {record.getMessage()}"
|
||||
|
||||
|
||||
# 全局标志,防止重复初始化
|
||||
_logging_configured = False
|
||||
|
||||
def setup_logging(
|
||||
level: str = "INFO",
|
||||
log_to_file: bool = False,
|
||||
log_file_path: Optional[str] = None,
|
||||
max_bytes: int = 10 * 1024 * 1024,
|
||||
backup_count: int = 30
|
||||
):
|
||||
"""
|
||||
配置统一的 Uvicorn 风格日志系统
|
||||
|
||||
Args:
|
||||
level: 日志级别 (DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
||||
log_to_file: 是否输出到文件
|
||||
log_file_path: 日志文件路径
|
||||
max_bytes: 单个日志文件最大字节数(默认10MB)
|
||||
backup_count: 保留的备份文件数量(默认30个)
|
||||
"""
|
||||
global _logging_configured
|
||||
|
||||
# 如果已经配置过,直接返回
|
||||
if _logging_configured:
|
||||
return logging.getLogger()
|
||||
|
||||
# 获取根日志器
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(getattr(logging, level.upper()))
|
||||
|
||||
# 清除已有的处理器,避免重复
|
||||
root_logger.handlers.clear()
|
||||
|
||||
# 1. 创建控制台处理器(带颜色)
|
||||
console_handler = logging.StreamHandler(sys.stderr)
|
||||
console_handler.setLevel(getattr(logging, level.upper()))
|
||||
console_formatter = UvicornFormatter(use_colors=True)
|
||||
console_handler.setFormatter(console_formatter)
|
||||
root_logger.addHandler(console_handler)
|
||||
|
||||
# 2. 创建文件处理器(如果启用)
|
||||
if log_to_file and log_file_path:
|
||||
# 确保日志目录存在
|
||||
log_file = Path(log_file_path)
|
||||
log_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 使用RotatingFileHandler实现日志轮转
|
||||
file_handler = RotatingFileHandler(
|
||||
filename=log_file_path,
|
||||
maxBytes=max_bytes,
|
||||
backupCount=backup_count,
|
||||
encoding='utf-8'
|
||||
)
|
||||
file_handler.setLevel(getattr(logging, level.upper()))
|
||||
|
||||
# 文件日志不使用颜色
|
||||
file_formatter = UvicornFormatter(use_colors=False)
|
||||
file_handler.setFormatter(file_formatter)
|
||||
root_logger.addHandler(file_handler)
|
||||
|
||||
# 记录日志配置信息
|
||||
root_logger.info(f"日志文件输出已启用: {log_file_path}")
|
||||
root_logger.info(f"日志轮转配置: 单文件最大{max_bytes / 1024 / 1024:.1f}MB, 保留{backup_count}个备份")
|
||||
|
||||
# 配置第三方库的日志级别
|
||||
_configure_third_party_loggers()
|
||||
|
||||
# 标记为已配置
|
||||
_logging_configured = True
|
||||
|
||||
return root_logger
|
||||
|
||||
|
||||
def _configure_third_party_loggers():
|
||||
"""配置第三方库的日志级别"""
|
||||
# SQLAlchemy - 禁用SQL日志
|
||||
logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING)
|
||||
logging.getLogger('sqlalchemy.pool').setLevel(logging.WARNING)
|
||||
logging.getLogger('sqlalchemy.dialects').setLevel(logging.WARNING)
|
||||
logging.getLogger('sqlalchemy.orm').setLevel(logging.WARNING)
|
||||
|
||||
# Watchfiles - 开发时的文件监控,降低级别
|
||||
logging.getLogger('watchfiles').setLevel(logging.WARNING)
|
||||
|
||||
# httpx - HTTP客户端
|
||||
logging.getLogger('httpx').setLevel(logging.WARNING)
|
||||
|
||||
# openai/anthropic - AI客户端库
|
||||
logging.getLogger('openai').setLevel(logging.WARNING)
|
||||
logging.getLogger('anthropic').setLevel(logging.WARNING)
|
||||
|
||||
# 应用模块 - 可根据需要调整
|
||||
logging.getLogger('app.services.ai_service').setLevel(logging.WARNING)
|
||||
logging.getLogger('app.api.wizard').setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
"""
|
||||
获取指定名称的日志器
|
||||
|
||||
Args:
|
||||
name: 日志器名称,通常使用 __name__
|
||||
|
||||
Returns:
|
||||
配置好的日志器实例
|
||||
"""
|
||||
return logging.getLogger(name)
|
||||
@@ -0,0 +1,176 @@
|
||||
"""FastAPI应用主入口"""
|
||||
from fastapi import FastAPI, Request, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import JSONResponse, FileResponse
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
from app.config import settings
|
||||
from app.database import close_db, _session_stats
|
||||
from app.logger import setup_logging, get_logger
|
||||
from app.middleware import RequestIDMiddleware
|
||||
from app.middleware.auth_middleware import AuthMiddleware
|
||||
|
||||
setup_logging(
|
||||
level=settings.log_level,
|
||||
log_to_file=settings.log_to_file,
|
||||
log_file_path=settings.log_file_path,
|
||||
max_bytes=settings.log_max_bytes,
|
||||
backup_count=settings.log_backup_count
|
||||
)
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
logger.info("应用启动,等待用户登录...")
|
||||
|
||||
yield
|
||||
await close_db()
|
||||
logger.info("应用已关闭")
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title=settings.app_name,
|
||||
version=settings.app_version,
|
||||
description="AI写小说工具 - 智能小说创作助手",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
"""处理请求验证错误"""
|
||||
logger.error(f"请求验证失败: {exc.errors()}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
content={
|
||||
"detail": "请求参数验证失败",
|
||||
"errors": exc.errors()
|
||||
}
|
||||
)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def global_exception_handler(request: Request, exc: Exception):
|
||||
"""处理所有未捕获的异常"""
|
||||
logger.error(f"未处理的异常: {type(exc).__name__}: {str(exc)}", exc_info=True)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={
|
||||
"detail": "服务器内部错误",
|
||||
"message": str(exc) if settings.debug else "请稍后重试"
|
||||
}
|
||||
)
|
||||
|
||||
app.add_middleware(RequestIDMiddleware)
|
||||
app.add_middleware(AuthMiddleware)
|
||||
|
||||
if settings.debug:
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
else:
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""健康检查"""
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.get("/health/db-sessions")
|
||||
async def db_session_stats():
|
||||
"""
|
||||
数据库会话统计(监控连接泄漏)
|
||||
|
||||
返回:
|
||||
- created: 总创建会话数
|
||||
- closed: 总关闭会话数
|
||||
- active: 当前活跃会话数(应该接近0)
|
||||
- errors: 错误次数
|
||||
- generator_exits: SSE断开次数
|
||||
- last_check: 最后检查时间
|
||||
"""
|
||||
return {
|
||||
"status": "ok",
|
||||
"session_stats": _session_stats,
|
||||
"warning": "活跃会话数过多" if _session_stats["active"] > 10 else None
|
||||
}
|
||||
|
||||
|
||||
from app.api import (
|
||||
projects, outlines, characters, chapters,
|
||||
wizard_stream, relationships, organizations,
|
||||
auth, users
|
||||
)
|
||||
|
||||
app.include_router(auth.router, prefix="/api")
|
||||
app.include_router(users.router, prefix="/api")
|
||||
|
||||
app.include_router(projects.router, prefix="/api")
|
||||
app.include_router(wizard_stream.router, prefix="/api")
|
||||
app.include_router(outlines.router, prefix="/api")
|
||||
app.include_router(characters.router, prefix="/api")
|
||||
app.include_router(chapters.router, prefix="/api")
|
||||
app.include_router(relationships.router, prefix="/api")
|
||||
app.include_router(organizations.router, prefix="/api")
|
||||
|
||||
static_dir = Path(__file__).parent.parent / "static"
|
||||
if static_dir.exists():
|
||||
app.mount("/assets", StaticFiles(directory=str(static_dir / "assets")), name="assets")
|
||||
|
||||
@app.get("/{full_path:path}")
|
||||
async def serve_spa(full_path: str):
|
||||
"""服务单页应用,所有非API路径返回index.html"""
|
||||
if full_path.startswith("api/"):
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={"detail": "API路径不存在"}
|
||||
)
|
||||
|
||||
file_path = static_dir / full_path
|
||||
if file_path.is_file():
|
||||
return FileResponse(file_path)
|
||||
|
||||
index_file = static_dir / "index.html"
|
||||
if index_file.exists():
|
||||
return FileResponse(index_file)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={"detail": "页面不存在"}
|
||||
)
|
||||
else:
|
||||
logger.warning("静态文件目录不存在,请先构建前端: cd frontend && npm run build")
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {
|
||||
"message": "欢迎使用AI Story Creator",
|
||||
"version": settings.app_version,
|
||||
"docs": "/docs",
|
||||
"notice": "请先构建前端: cd frontend && npm run build"
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(
|
||||
"app.main:app",
|
||||
host=settings.app_host,
|
||||
port=settings.app_port,
|
||||
reload=settings.debug
|
||||
)
|
||||
@@ -0,0 +1,4 @@
|
||||
"""中间件模块"""
|
||||
from .request_id import RequestIDMiddleware
|
||||
|
||||
__all__ = ['RequestIDMiddleware']
|
||||
@@ -0,0 +1,39 @@
|
||||
"""
|
||||
认证中间件 - 从 Cookie 中提取用户信息并注入到 request.state
|
||||
"""
|
||||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from app.user_manager import user_manager
|
||||
|
||||
|
||||
class AuthMiddleware(BaseHTTPMiddleware):
|
||||
"""认证中间件"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""
|
||||
处理请求,从 Cookie 中提取用户 ID 并注入到 request.state
|
||||
"""
|
||||
# 从 Cookie 中获取用户 ID
|
||||
user_id = request.cookies.get("user_id")
|
||||
|
||||
# 注入到 request.state
|
||||
if user_id:
|
||||
user = await user_manager.get_user(user_id)
|
||||
if user:
|
||||
request.state.user_id = user_id
|
||||
request.state.user = user
|
||||
request.state.is_admin = user.is_admin
|
||||
else:
|
||||
# 用户不存在,清除状态
|
||||
request.state.user_id = None
|
||||
request.state.user = None
|
||||
request.state.is_admin = False
|
||||
else:
|
||||
# 未登录
|
||||
request.state.user_id = None
|
||||
request.state.user = None
|
||||
request.state.is_admin = False
|
||||
|
||||
# 继续处理请求
|
||||
response = await call_next(request)
|
||||
return response
|
||||
@@ -0,0 +1,78 @@
|
||||
"""请求追踪ID中间件"""
|
||||
import uuid
|
||||
import logging
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from typing import Callable
|
||||
|
||||
|
||||
class RequestIDMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
请求追踪ID中间件
|
||||
|
||||
为每个请求生成唯一ID,并添加到日志上下文中
|
||||
"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
"""
|
||||
处理请求,添加追踪ID
|
||||
|
||||
Args:
|
||||
request: 请求对象
|
||||
call_next: 下一个处理器
|
||||
|
||||
Returns:
|
||||
响应对象
|
||||
"""
|
||||
# 从请求头获取追踪ID,或生成新的
|
||||
request_id = request.headers.get('X-Request-ID') or str(uuid.uuid4())
|
||||
|
||||
# 将请求ID存储到request.state中,方便后续访问
|
||||
request.state.request_id = request_id
|
||||
|
||||
# 创建日志过滤器,自动添加request_id到日志记录
|
||||
log_filter = RequestIDFilter(request_id)
|
||||
|
||||
# 获取根日志器并添加过滤器
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.addFilter(log_filter)
|
||||
|
||||
try:
|
||||
# 处理请求
|
||||
response = await call_next(request)
|
||||
|
||||
# 将请求ID添加到响应头
|
||||
response.headers['X-Request-ID'] = request_id
|
||||
|
||||
return response
|
||||
finally:
|
||||
# 移除过滤器,避免影响其他请求
|
||||
root_logger.removeFilter(log_filter)
|
||||
|
||||
|
||||
class RequestIDFilter(logging.Filter):
|
||||
"""日志过滤器,为日志记录添加request_id属性"""
|
||||
|
||||
def __init__(self, request_id: str):
|
||||
"""
|
||||
初始化过滤器
|
||||
|
||||
Args:
|
||||
request_id: 请求追踪ID
|
||||
"""
|
||||
super().__init__()
|
||||
self.request_id = request_id
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
"""
|
||||
为日志记录添加request_id属性
|
||||
|
||||
Args:
|
||||
record: 日志记录
|
||||
|
||||
Returns:
|
||||
True(不过滤任何日志)
|
||||
"""
|
||||
record.request_id = self.request_id
|
||||
return True
|
||||
@@ -0,0 +1,26 @@
|
||||
"""数据库模型"""
|
||||
from app.models.project import Project
|
||||
from app.models.outline import Outline
|
||||
from app.models.character import Character
|
||||
from app.models.chapter import Chapter
|
||||
from app.models.generation_history import GenerationHistory
|
||||
from app.models.settings import Settings
|
||||
from app.models.relationship import (
|
||||
RelationshipType,
|
||||
CharacterRelationship,
|
||||
Organization,
|
||||
OrganizationMember
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Project",
|
||||
"Outline",
|
||||
"Character",
|
||||
"Chapter",
|
||||
"GenerationHistory",
|
||||
"Settings",
|
||||
"RelationshipType",
|
||||
"CharacterRelationship",
|
||||
"Organization",
|
||||
"OrganizationMember",
|
||||
]
|
||||
@@ -0,0 +1,24 @@
|
||||
"""章节数据模型"""
|
||||
from sqlalchemy import Column, String, Text, Integer, DateTime, ForeignKey
|
||||
from sqlalchemy.sql import func
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class Chapter(Base):
|
||||
"""章节表"""
|
||||
__tablename__ = "chapters"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)
|
||||
chapter_number = Column(Integer, nullable=False, comment="章节序号")
|
||||
title = Column(String(200), nullable=False, comment="章节标题")
|
||||
content = Column(Text, comment="章节内容")
|
||||
summary = Column(Text, comment="章节摘要")
|
||||
word_count = Column(Integer, default=0, comment="字数统计")
|
||||
status = Column(String(20), default="draft", comment="章节状态")
|
||||
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Chapter(id={self.id}, chapter_number={self.chapter_number}, title={self.title})>"
|
||||
@@ -0,0 +1,44 @@
|
||||
"""角色数据模型"""
|
||||
from sqlalchemy import Column, String, Text, DateTime, ForeignKey, Boolean
|
||||
from sqlalchemy.sql import func
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class Character(Base):
|
||||
"""角色表(包括角色和组织)"""
|
||||
__tablename__ = "characters"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)
|
||||
|
||||
# 基本信息
|
||||
name = Column(String(100), nullable=False, comment="角色/组织名称")
|
||||
age = Column(String(20), comment="年龄")
|
||||
gender = Column(String(20), comment="性别")
|
||||
is_organization = Column(Boolean, default=False, comment="是否为组织")
|
||||
|
||||
# 角色类型:protagonist(主角)/supporting(配角)/antagonist(反派)
|
||||
role_type = Column(String(50), comment="角色类型")
|
||||
|
||||
# 角色详细信息
|
||||
personality = Column(Text, comment="性格特点/组织特性")
|
||||
background = Column(Text, comment="背景故事")
|
||||
appearance = Column(Text, comment="外貌描述")
|
||||
relationships = Column(Text, comment="人物关系(JSON)")
|
||||
|
||||
# 组织特有字段
|
||||
organization_type = Column(String(100), comment="组织类型")
|
||||
organization_purpose = Column(String(500), comment="组织目的")
|
||||
organization_members = Column(Text, comment="组织成员(JSON)")
|
||||
|
||||
# 其他
|
||||
avatar_url = Column(String(500), comment="头像URL")
|
||||
traits = Column(Text, comment="特征标签(JSON)")
|
||||
|
||||
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
|
||||
def __repr__(self):
|
||||
entity_type = "组织" if self.is_organization else "角色"
|
||||
return f"<Character(id={self.id}, name={self.name}, type={entity_type})>"
|
||||
@@ -0,0 +1,23 @@
|
||||
"""生成历史数据模型"""
|
||||
from sqlalchemy import Column, String, Text, Integer, Float, DateTime, ForeignKey
|
||||
from sqlalchemy.sql import func
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class GenerationHistory(Base):
|
||||
"""生成历史表"""
|
||||
__tablename__ = "generation_history"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)
|
||||
chapter_id = Column(String(36), ForeignKey("chapters.id", ondelete="SET NULL"), nullable=True)
|
||||
prompt = Column(Text, comment="使用的提示词")
|
||||
generated_content = Column(Text, comment="生成的内容")
|
||||
model = Column(String(50), comment="使用的模型")
|
||||
tokens_used = Column(Integer, comment="消耗的token数")
|
||||
generation_time = Column(Float, comment="生成耗时(秒)")
|
||||
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<GenerationHistory(id={self.id}, model={self.model})>"
|
||||
@@ -0,0 +1,22 @@
|
||||
"""大纲数据模型"""
|
||||
from sqlalchemy import Column, String, Text, Integer, DateTime, ForeignKey
|
||||
from sqlalchemy.sql import func
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class Outline(Base):
|
||||
"""大纲表"""
|
||||
__tablename__ = "outlines"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)
|
||||
title = Column(String(200), nullable=False, comment="大纲标题")
|
||||
content = Column(Text, comment="大纲内容")
|
||||
structure = Column(Text, comment="结构化大纲数据(JSON)")
|
||||
order_index = Column(Integer, comment="排序序号")
|
||||
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Outline(id={self.id}, title={self.title})>"
|
||||
@@ -0,0 +1,38 @@
|
||||
"""项目数据模型"""
|
||||
from sqlalchemy import Column, String, Text, DateTime, Integer
|
||||
from sqlalchemy.sql import func
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class Project(Base):
|
||||
"""项目表"""
|
||||
__tablename__ = "projects"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
title = Column(String(200), nullable=False, comment="项目标题")
|
||||
description = Column(Text, comment="项目简介")
|
||||
theme = Column(Text, comment="主题")
|
||||
genre = Column(String(50), comment="小说类型")
|
||||
target_words = Column(Integer, default=0, comment="目标字数")
|
||||
current_words = Column(Integer, default=0, comment="当前字数")
|
||||
status = Column(String(20), default="planning", comment="创作状态")
|
||||
wizard_status = Column(String(20), default="incomplete", comment="向导完成状态: incomplete/completed")
|
||||
wizard_step = Column(Integer, default=0, comment="向导当前步骤: 0-4")
|
||||
|
||||
# 世界构建字段
|
||||
world_time_period = Column(Text, comment="时间背景")
|
||||
world_location = Column(Text, comment="地理位置")
|
||||
world_atmosphere = Column(Text, comment="氛围基调")
|
||||
world_rules = Column(Text, comment="世界规则")
|
||||
|
||||
# 项目配置
|
||||
chapter_count = Column(Integer, comment="章节数量")
|
||||
narrative_perspective = Column(String(50), comment="叙事视角:first_person/third_person/omniscient")
|
||||
character_count = Column(Integer, default=5, comment="角色数量")
|
||||
|
||||
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Project(id={self.id}, title={self.title})>"
|
||||
@@ -0,0 +1,116 @@
|
||||
"""角色关系和组织管理数据模型"""
|
||||
from sqlalchemy import Column, String, Integer, Text, DateTime, ForeignKey, Boolean
|
||||
from sqlalchemy.sql import func
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class RelationshipType(Base):
|
||||
"""关系类型定义表"""
|
||||
__tablename__ = "relationship_types"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
|
||||
name = Column(String(50), nullable=False, comment="关系名称")
|
||||
category = Column(String(20), nullable=False, comment="分类:family/social/hostile/professional")
|
||||
reverse_name = Column(String(50), comment="反向关系名称")
|
||||
intimacy_range = Column(String(20), comment="亲密度范围:high/medium/low")
|
||||
icon = Column(String(50), comment="图标标识")
|
||||
description = Column(Text, comment="关系描述")
|
||||
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<RelationshipType(id={self.id}, name={self.name}, category={self.category})>"
|
||||
|
||||
|
||||
class CharacterRelationship(Base):
|
||||
"""角色关系表"""
|
||||
__tablename__ = "character_relationships"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="关系ID")
|
||||
project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False, index=True, comment="项目ID")
|
||||
|
||||
# 关系双方
|
||||
character_from_id = Column(String(36), ForeignKey("characters.id", ondelete="CASCADE"), nullable=False, index=True, comment="角色A的ID")
|
||||
character_to_id = Column(String(36), ForeignKey("characters.id", ondelete="CASCADE"), nullable=False, index=True, comment="角色B的ID")
|
||||
|
||||
# 关系类型
|
||||
relationship_type_id = Column(Integer, ForeignKey("relationship_types.id"), index=True, comment="关系类型ID")
|
||||
relationship_name = Column(String(100), comment="自定义关系名称")
|
||||
|
||||
# 关系属性
|
||||
intimacy_level = Column(Integer, default=50, comment="亲密度:0-100")
|
||||
status = Column(String(20), default="active", comment="状态:active/broken/past/complicated")
|
||||
description = Column(Text, comment="关系详细描述")
|
||||
|
||||
# 故事时间线
|
||||
started_at = Column(String(100), comment="关系开始时间(故事时间)")
|
||||
ended_at = Column(String(100), comment="关系结束时间(故事时间)")
|
||||
|
||||
# 来源标识
|
||||
source = Column(String(20), default="ai", comment="来源:ai/manual/imported")
|
||||
|
||||
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<CharacterRelationship(id={self.id}, from={self.character_from_id}, to={self.character_to_id})>"
|
||||
|
||||
|
||||
class Organization(Base):
|
||||
"""组织详情表"""
|
||||
__tablename__ = "organizations"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="组织ID")
|
||||
character_id = Column(String(36), ForeignKey("characters.id", ondelete="CASCADE"), nullable=False, unique=True, comment="关联的角色ID")
|
||||
project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False, index=True, comment="项目ID")
|
||||
|
||||
# 组织层级
|
||||
parent_org_id = Column(String(36), ForeignKey("organizations.id", ondelete="SET NULL"), comment="父组织ID")
|
||||
level = Column(Integer, default=0, comment="组织层级")
|
||||
|
||||
# 组织属性
|
||||
power_level = Column(Integer, default=50, comment="势力等级:0-100")
|
||||
member_count = Column(Integer, default=0, comment="成员数量")
|
||||
location = Column(Text, comment="所在地")
|
||||
|
||||
# 组织特色
|
||||
motto = Column(String(200), comment="宗旨/口号")
|
||||
color = Column(String(20), comment="代表颜色")
|
||||
|
||||
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Organization(id={self.id}, character_id={self.character_id})>"
|
||||
|
||||
|
||||
class OrganizationMember(Base):
|
||||
"""组织成员关系表"""
|
||||
__tablename__ = "organization_members"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="成员关系ID")
|
||||
organization_id = Column(String(36), ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False, index=True, comment="组织ID")
|
||||
character_id = Column(String(36), ForeignKey("characters.id", ondelete="CASCADE"), nullable=False, index=True, comment="角色ID")
|
||||
|
||||
# 职位信息
|
||||
position = Column(String(100), nullable=False, comment="职位名称")
|
||||
rank = Column(Integer, default=0, comment="职位等级")
|
||||
|
||||
# 成员状态
|
||||
status = Column(String(20), default="active", comment="状态:active/retired/expelled/deceased")
|
||||
joined_at = Column(String(100), comment="加入时间(故事时间)")
|
||||
left_at = Column(String(100), comment="离开时间(故事时间)")
|
||||
|
||||
# 成员属性
|
||||
loyalty = Column(Integer, default=50, comment="忠诚度:0-100")
|
||||
contribution = Column(Integer, default=0, comment="贡献度:0-100")
|
||||
|
||||
# 来源标识
|
||||
source = Column(String(20), default="ai", comment="来源:ai/manual")
|
||||
|
||||
notes = Column(Text, comment="备注")
|
||||
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<OrganizationMember(id={self.id}, org={self.organization_id}, char={self.character_id})>"
|
||||
@@ -0,0 +1,24 @@
|
||||
"""设置数据模型"""
|
||||
from sqlalchemy import Column, String, Text, Float, Integer, DateTime
|
||||
from sqlalchemy.sql import func
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class Settings(Base):
|
||||
"""设置表"""
|
||||
__tablename__ = "settings"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
api_provider = Column(String(50), default="openai", comment="API提供商")
|
||||
api_key = Column(String(500), comment="API密钥")
|
||||
api_base_url = Column(String(500), comment="自定义API地址")
|
||||
model_name = Column(String(100), default="gpt-4", comment="模型名称")
|
||||
temperature = Column(Float, default=0.7, comment="温度参数")
|
||||
max_tokens = Column(Integer, default=2000, comment="最大token数")
|
||||
preferences = Column(Text, comment="其他偏好设置(JSON)")
|
||||
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Settings(id={self.id}, api_provider={self.api_provider})>"
|
||||
@@ -0,0 +1 @@
|
||||
"""Pydantic数据模型"""
|
||||
@@ -0,0 +1,57 @@
|
||||
"""章节相关的Pydantic模型"""
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class ChapterBase(BaseModel):
|
||||
"""章节基础模型"""
|
||||
title: str = Field(..., description="章节标题")
|
||||
chapter_number: int = Field(..., description="章节序号")
|
||||
content: Optional[str] = Field(None, description="章节内容")
|
||||
summary: Optional[str] = Field(None, description="章节摘要")
|
||||
word_count: Optional[int] = Field(0, description="字数")
|
||||
status: Optional[str] = Field("draft", description="章节状态")
|
||||
|
||||
|
||||
class ChapterCreate(BaseModel):
|
||||
"""创建章节的请求模型"""
|
||||
project_id: str = Field(..., description="所属项目ID")
|
||||
title: str = Field(..., description="章节标题")
|
||||
chapter_number: int = Field(..., description="章节序号")
|
||||
content: Optional[str] = Field(None, description="章节内容")
|
||||
summary: Optional[str] = Field(None, description="章节摘要")
|
||||
status: Optional[str] = Field("draft", description="章节状态")
|
||||
|
||||
|
||||
class ChapterUpdate(BaseModel):
|
||||
"""更新章节的请求模型"""
|
||||
title: Optional[str] = None
|
||||
content: Optional[str] = None
|
||||
# chapter_number 不允许修改,只能通过大纲的重排序来调整
|
||||
summary: Optional[str] = None
|
||||
# word_count 自动计算,不允许手动修改
|
||||
status: Optional[str] = None
|
||||
|
||||
|
||||
class ChapterResponse(BaseModel):
|
||||
"""章节响应模型"""
|
||||
id: str
|
||||
project_id: str
|
||||
title: str
|
||||
chapter_number: int
|
||||
content: Optional[str] = None
|
||||
summary: Optional[str] = None
|
||||
word_count: int = 0
|
||||
status: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ChapterListResponse(BaseModel):
|
||||
"""章节列表响应模型"""
|
||||
total: int
|
||||
items: list[ChapterResponse]
|
||||
@@ -0,0 +1,67 @@
|
||||
"""角色相关的Pydantic模型"""
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class CharacterBase(BaseModel):
|
||||
"""角色基础模型"""
|
||||
name: str = Field(..., description="角色/组织姓名")
|
||||
age: Optional[str] = Field(None, description="年龄")
|
||||
gender: Optional[str] = Field(None, description="性别")
|
||||
is_organization: bool = Field(False, description="是否为组织")
|
||||
role_type: Optional[str] = Field(None, description="角色类型:protagonist/supporting/antagonist")
|
||||
personality: Optional[str] = Field(None, description="性格特点/组织特性")
|
||||
background: Optional[str] = Field(None, description="背景故事")
|
||||
appearance: Optional[str] = Field(None, description="外貌特征")
|
||||
relationships: Optional[str] = Field(None, description="人际关系(JSON)")
|
||||
organization_type: Optional[str] = Field(None, description="组织类型")
|
||||
organization_purpose: Optional[str] = Field(None, description="组织目的")
|
||||
organization_members: Optional[str] = Field(None, description="组织成员(JSON)")
|
||||
traits: Optional[str] = Field(None, description="特征标签(JSON)")
|
||||
|
||||
|
||||
class CharacterUpdate(BaseModel):
|
||||
"""更新角色的请求模型"""
|
||||
name: Optional[str] = None
|
||||
age: Optional[str] = None
|
||||
gender: Optional[str] = None
|
||||
is_organization: Optional[bool] = None
|
||||
role_type: Optional[str] = None
|
||||
personality: Optional[str] = None
|
||||
background: Optional[str] = None
|
||||
appearance: Optional[str] = None
|
||||
relationships: Optional[str] = None
|
||||
organization_type: Optional[str] = None
|
||||
organization_purpose: Optional[str] = None
|
||||
organization_members: Optional[str] = None
|
||||
traits: Optional[str] = None
|
||||
|
||||
|
||||
class CharacterResponse(CharacterBase):
|
||||
"""角色响应模型"""
|
||||
id: str
|
||||
project_id: str
|
||||
avatar_url: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class CharacterGenerateRequest(BaseModel):
|
||||
"""AI生成角色的请求模型"""
|
||||
project_id: str = Field(..., description="项目ID")
|
||||
name: Optional[str] = Field(None, description="角色名称")
|
||||
role_type: Optional[str] = Field(None, description="角色类型")
|
||||
background: Optional[str] = Field(None, description="角色背景")
|
||||
requirements: Optional[str] = Field(None, description="特殊要求")
|
||||
provider: Optional[str] = Field(None, description="AI提供商")
|
||||
model: Optional[str] = Field(None, description="AI模型")
|
||||
|
||||
|
||||
class CharacterListResponse(BaseModel):
|
||||
"""角色列表响应模型"""
|
||||
total: int
|
||||
items: List[CharacterResponse]
|
||||
@@ -0,0 +1,88 @@
|
||||
"""大纲相关的Pydantic模型"""
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class OutlineBase(BaseModel):
|
||||
"""大纲基础模型"""
|
||||
title: str = Field(..., description="章节标题")
|
||||
content: str = Field(..., description="章节内容概要")
|
||||
|
||||
|
||||
class OutlineCreate(BaseModel):
|
||||
"""创建大纲的请求模型"""
|
||||
project_id: str = Field(..., description="所属项目ID")
|
||||
title: str = Field(..., description="章节标题")
|
||||
content: str = Field(..., description="章节内容概要")
|
||||
order_index: int = Field(..., description="章节序号", ge=1)
|
||||
structure: Optional[str] = Field(None, description="结构化大纲数据(JSON)")
|
||||
|
||||
|
||||
class OutlineUpdate(BaseModel):
|
||||
"""更新大纲的请求模型"""
|
||||
title: Optional[str] = None
|
||||
content: Optional[str] = None
|
||||
# order_index 不允许通过普通更新修改,只能通过 reorder_outlines 接口批量调整
|
||||
# structure 暂不支持修改
|
||||
|
||||
|
||||
class OutlineResponse(BaseModel):
|
||||
"""大纲响应模型"""
|
||||
id: str
|
||||
project_id: str
|
||||
title: str
|
||||
content: str
|
||||
structure: Optional[str] = None
|
||||
order_index: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class OutlineGenerateRequest(BaseModel):
|
||||
"""AI生成大纲的请求模型 - 支持全新生成和智能续写"""
|
||||
project_id: str = Field(..., description="项目ID")
|
||||
genre: Optional[str] = Field(None, description="小说类型,如:玄幻、都市、悬疑等")
|
||||
theme: str = Field(..., description="小说主题")
|
||||
chapter_count: int = Field(..., ge=1, description="章节数量")
|
||||
narrative_perspective: str = Field(..., description="叙事视角")
|
||||
world_context: Optional[dict] = Field(None, description="世界观背景")
|
||||
characters_context: Optional[list] = Field(None, description="角色信息")
|
||||
target_words: int = Field(100000, description="目标字数")
|
||||
requirements: Optional[str] = Field(None, description="其他特殊要求")
|
||||
provider: Optional[str] = Field(None, description="AI提供商")
|
||||
model: Optional[str] = Field(None, description="AI模型")
|
||||
|
||||
# 续写相关参数
|
||||
mode: str = Field("auto", description="生成模式: auto(自动判断), new(全新生成), continue(续写)")
|
||||
story_direction: Optional[str] = Field(None, description="故事发展方向提示(续写时使用)")
|
||||
plot_stage: str = Field("development", description="情节阶段: development(发展), climax(高潮), ending(结局)")
|
||||
keep_existing: bool = Field(False, description="是否保留现有大纲(续写时)")
|
||||
|
||||
|
||||
class ChapterOutlineGenerateRequest(BaseModel):
|
||||
"""为单个章节生成大纲的请求模型"""
|
||||
outline_id: str = Field(..., description="大纲ID")
|
||||
context: Optional[str] = Field(None, description="额外上下文")
|
||||
provider: Optional[str] = Field(None, description="AI提供商")
|
||||
model: Optional[str] = Field(None, description="AI模型")
|
||||
|
||||
|
||||
class OutlineListResponse(BaseModel):
|
||||
"""大纲列表响应模型"""
|
||||
total: int
|
||||
items: list[OutlineResponse]
|
||||
|
||||
|
||||
class OutlineReorderItem(BaseModel):
|
||||
"""单个大纲重排序项"""
|
||||
id: str = Field(..., description="大纲ID")
|
||||
order_index: int = Field(..., description="新的序号", ge=1)
|
||||
|
||||
|
||||
class OutlineReorderRequest(BaseModel):
|
||||
"""大纲批量重排序请求"""
|
||||
orders: list[OutlineReorderItem] = Field(..., description="排序列表")
|
||||
@@ -0,0 +1,20 @@
|
||||
"""AI去味相关的Pydantic模型"""
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class PolishRequest(BaseModel):
|
||||
"""AI去味请求模型"""
|
||||
original_text: str = Field(..., description="原始文本(AI生成的文本)")
|
||||
project_id: Optional[int] = Field(None, description="项目ID(可选,用于记录历史)")
|
||||
provider: Optional[str] = Field(None, description="AI提供商")
|
||||
model: Optional[str] = Field(None, description="AI模型")
|
||||
temperature: Optional[float] = Field(0.8, description="温度参数,建议0.7-0.9")
|
||||
|
||||
|
||||
class PolishResponse(BaseModel):
|
||||
"""AI去味响应模型"""
|
||||
original_text: str = Field(..., description="原始文本")
|
||||
polished_text: str = Field(..., description="去味后的文本")
|
||||
word_count_before: int = Field(..., description="处理前字数")
|
||||
word_count_after: int = Field(..., description="处理后字数")
|
||||
@@ -0,0 +1,83 @@
|
||||
"""项目相关的Pydantic模型"""
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class ProjectBase(BaseModel):
|
||||
"""项目基础模型"""
|
||||
title: str = Field(..., description="项目标题")
|
||||
description: Optional[str] = Field(None, description="项目描述")
|
||||
theme: Optional[str] = Field(None, description="主题")
|
||||
genre: Optional[str] = Field(None, description="小说类型")
|
||||
target_words: Optional[int] = Field(None, description="目标字数")
|
||||
|
||||
|
||||
class ProjectCreate(ProjectBase):
|
||||
"""创建项目的请求模型"""
|
||||
pass
|
||||
|
||||
|
||||
class ProjectUpdate(BaseModel):
|
||||
"""更新项目的请求模型"""
|
||||
title: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
theme: Optional[str] = None
|
||||
genre: Optional[str] = None
|
||||
target_words: Optional[int] = None
|
||||
status: Optional[str] = None
|
||||
# wizard_status 和 wizard_step 只能通过向导API修改,普通更新不允许
|
||||
world_time_period: Optional[str] = None
|
||||
world_location: Optional[str] = None
|
||||
world_atmosphere: Optional[str] = None
|
||||
world_rules: Optional[str] = None
|
||||
chapter_count: Optional[int] = None
|
||||
narrative_perspective: Optional[str] = None
|
||||
character_count: Optional[int] = None
|
||||
# current_words 由章节内容自动计算,不允许手动修改
|
||||
|
||||
|
||||
class ProjectResponse(ProjectBase):
|
||||
"""项目响应模型"""
|
||||
id: str # UUID字符串
|
||||
status: str
|
||||
current_words: int
|
||||
wizard_status: Optional[str] = None
|
||||
wizard_step: Optional[int] = None
|
||||
world_time_period: Optional[str] = None
|
||||
world_location: Optional[str] = None
|
||||
world_atmosphere: Optional[str] = None
|
||||
world_rules: Optional[str] = None
|
||||
chapter_count: Optional[int] = None
|
||||
narrative_perspective: Optional[str] = None
|
||||
character_count: Optional[int] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ProjectListResponse(BaseModel):
|
||||
"""项目列表响应模型"""
|
||||
total: int
|
||||
items: list[ProjectResponse]
|
||||
|
||||
|
||||
class ProjectWizardRequest(BaseModel):
|
||||
"""项目创建向导请求模型"""
|
||||
title: str = Field(..., description="书名")
|
||||
theme: str = Field(..., description="主题")
|
||||
genre: Optional[str] = Field(None, description="类型")
|
||||
chapter_count: int = Field(..., ge=1, description="章节数量")
|
||||
narrative_perspective: str = Field(..., description="叙事视角")
|
||||
character_count: int = Field(5, ge=5, description="角色数量(至少5个)")
|
||||
target_words: Optional[int] = Field(None, description="目标字数")
|
||||
|
||||
|
||||
class WorldBuildingResponse(BaseModel):
|
||||
"""世界构建响应模型"""
|
||||
time_period: str = Field(..., description="时间背景")
|
||||
location: str = Field(..., description="地理位置")
|
||||
atmosphere: str = Field(..., description="氛围基调")
|
||||
rules: str = Field(..., description="世界规则")
|
||||
@@ -0,0 +1,204 @@
|
||||
"""关系管理相关的Pydantic模型"""
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
# ============ 关系类型相关 ============
|
||||
|
||||
class RelationshipTypeResponse(BaseModel):
|
||||
"""关系类型响应模型"""
|
||||
id: int
|
||||
name: str
|
||||
category: str
|
||||
reverse_name: Optional[str] = None
|
||||
intimacy_range: Optional[str] = None
|
||||
icon: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# ============ 角色关系相关 ============
|
||||
|
||||
class CharacterRelationshipBase(BaseModel):
|
||||
"""角色关系基础模型"""
|
||||
relationship_type_id: Optional[int] = Field(None, description="关系类型ID")
|
||||
relationship_name: Optional[str] = Field(None, description="自定义关系名称")
|
||||
intimacy_level: int = Field(50, ge=0, le=100, description="亲密度:0-100")
|
||||
status: str = Field("active", description="状态:active/broken/past/complicated")
|
||||
description: Optional[str] = Field(None, description="关系描述")
|
||||
started_at: Optional[str] = Field(None, description="关系开始时间(故事时间)")
|
||||
ended_at: Optional[str] = Field(None, description="关系结束时间")
|
||||
|
||||
|
||||
class CharacterRelationshipCreate(CharacterRelationshipBase):
|
||||
"""创建角色关系的请求模型"""
|
||||
project_id: str = Field(..., description="项目ID")
|
||||
character_from_id: str = Field(..., description="角色A的ID")
|
||||
character_to_id: str = Field(..., description="角色B的ID")
|
||||
|
||||
|
||||
class CharacterRelationshipUpdate(BaseModel):
|
||||
"""更新角色关系的请求模型"""
|
||||
relationship_type_id: Optional[int] = None
|
||||
relationship_name: Optional[str] = None
|
||||
intimacy_level: Optional[int] = Field(None, ge=0, le=100)
|
||||
status: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
started_at: Optional[str] = None
|
||||
ended_at: Optional[str] = None
|
||||
|
||||
|
||||
class CharacterRelationshipResponse(CharacterRelationshipBase):
|
||||
"""角色关系响应模型"""
|
||||
id: str
|
||||
project_id: str
|
||||
character_from_id: str
|
||||
character_to_id: str
|
||||
source: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class RelationshipGraphNode(BaseModel):
|
||||
"""关系图谱节点"""
|
||||
id: str
|
||||
name: str
|
||||
type: str # character / organization
|
||||
role_type: Optional[str] = None
|
||||
avatar: Optional[str] = None
|
||||
|
||||
|
||||
class RelationshipGraphLink(BaseModel):
|
||||
"""关系图谱连线"""
|
||||
source: str
|
||||
target: str
|
||||
relationship: str
|
||||
intimacy: int
|
||||
status: str
|
||||
|
||||
|
||||
class RelationshipGraphData(BaseModel):
|
||||
"""关系图谱数据"""
|
||||
nodes: List[RelationshipGraphNode]
|
||||
links: List[RelationshipGraphLink]
|
||||
|
||||
|
||||
# ============ 组织相关 ============
|
||||
|
||||
class OrganizationBase(BaseModel):
|
||||
"""组织基础模型"""
|
||||
parent_org_id: Optional[str] = Field(None, description="父组织ID")
|
||||
level: int = Field(0, description="组织层级")
|
||||
power_level: int = Field(50, ge=0, le=100, description="势力等级")
|
||||
location: Optional[str] = Field(None, description="所在地")
|
||||
motto: Optional[str] = Field(None, description="组织宗旨")
|
||||
color: Optional[str] = Field(None, description="代表颜色")
|
||||
|
||||
|
||||
class OrganizationCreate(OrganizationBase):
|
||||
"""创建组织的请求模型"""
|
||||
character_id: str = Field(..., description="关联的角色ID(组织记录)")
|
||||
project_id: str = Field(..., description="项目ID")
|
||||
|
||||
|
||||
class OrganizationUpdate(BaseModel):
|
||||
"""更新组织的请求模型"""
|
||||
parent_org_id: Optional[str] = None
|
||||
level: Optional[int] = None
|
||||
power_level: Optional[int] = Field(None, ge=0, le=100)
|
||||
location: Optional[str] = None
|
||||
motto: Optional[str] = None
|
||||
color: Optional[str] = None
|
||||
|
||||
|
||||
class OrganizationResponse(OrganizationBase):
|
||||
"""组织响应模型"""
|
||||
id: str
|
||||
character_id: str
|
||||
project_id: str
|
||||
member_count: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class OrganizationDetailResponse(BaseModel):
|
||||
"""组织详情响应(包含基本信息)"""
|
||||
id: str
|
||||
character_id: str
|
||||
name: str
|
||||
type: Optional[str] = None
|
||||
purpose: Optional[str] = None
|
||||
member_count: int
|
||||
power_level: int
|
||||
location: Optional[str] = None
|
||||
motto: Optional[str] = None
|
||||
color: Optional[str] = None
|
||||
|
||||
|
||||
# ============ 组织成员相关 ============
|
||||
|
||||
class OrganizationMemberBase(BaseModel):
|
||||
"""组织成员基础模型"""
|
||||
position: str = Field(..., description="职位名称")
|
||||
rank: int = Field(0, description="职位等级")
|
||||
status: str = Field("active", description="状态:active/retired/expelled/deceased")
|
||||
joined_at: Optional[str] = Field(None, description="加入时间(故事时间)")
|
||||
left_at: Optional[str] = Field(None, description="离开时间")
|
||||
loyalty: int = Field(50, ge=0, le=100, description="忠诚度")
|
||||
contribution: int = Field(0, ge=0, le=100, description="贡献度")
|
||||
notes: Optional[str] = Field(None, description="备注")
|
||||
|
||||
|
||||
class OrganizationMemberCreate(OrganizationMemberBase):
|
||||
"""创建组织成员的请求模型"""
|
||||
character_id: str = Field(..., description="角色ID")
|
||||
|
||||
|
||||
class OrganizationMemberUpdate(BaseModel):
|
||||
"""更新组织成员的请求模型"""
|
||||
position: Optional[str] = None
|
||||
rank: Optional[int] = None
|
||||
status: Optional[str] = None
|
||||
joined_at: Optional[str] = None
|
||||
left_at: Optional[str] = None
|
||||
loyalty: Optional[int] = Field(None, ge=0, le=100)
|
||||
contribution: Optional[int] = Field(None, ge=0, le=100)
|
||||
notes: Optional[str] = None
|
||||
|
||||
|
||||
class OrganizationMemberResponse(OrganizationMemberBase):
|
||||
"""组织成员响应模型"""
|
||||
id: str
|
||||
organization_id: str
|
||||
character_id: str
|
||||
source: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class OrganizationMemberDetailResponse(BaseModel):
|
||||
"""组织成员详情响应(包含角色信息)"""
|
||||
id: str
|
||||
character_id: str
|
||||
character_name: str
|
||||
position: str
|
||||
rank: int
|
||||
loyalty: int
|
||||
contribution: int
|
||||
status: str
|
||||
joined_at: Optional[str] = None
|
||||
left_at: Optional[str] = None
|
||||
notes: Optional[str] = None
|
||||
@@ -0,0 +1 @@
|
||||
"""服务层模块"""
|
||||
@@ -0,0 +1,363 @@
|
||||
"""AI服务封装 - 统一的OpenAI和Claude接口"""
|
||||
from typing import Optional, AsyncGenerator, List, Dict, Any
|
||||
from openai import AsyncOpenAI
|
||||
from anthropic import AsyncAnthropic
|
||||
from app.config import settings
|
||||
from app.logger import get_logger
|
||||
import httpx
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AIService:
|
||||
"""AI服务统一接口"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化AI客户端(优化并发性能)"""
|
||||
# 初始化OpenAI客户端
|
||||
if settings.openai_api_key:
|
||||
# 创建自定义的httpx客户端来避免proxies参数问题
|
||||
try:
|
||||
# 配置连接池限制,支持高并发
|
||||
# max_keepalive_connections: 保持活跃的连接数(提高复用率)
|
||||
# max_connections: 最大并发连接数(防止资源耗尽)
|
||||
limits = httpx.Limits(
|
||||
max_keepalive_connections=50, # 保持50个活跃连接
|
||||
max_connections=100, # 最多100个并发连接
|
||||
keepalive_expiry=30.0 # 30秒后过期未使用的连接
|
||||
)
|
||||
|
||||
# 使用httpx.AsyncClient并设置超时和连接池
|
||||
# connect: 连接超时10秒
|
||||
# read: 读取超时180秒(3分钟,适合长文本生成)
|
||||
# write: 写入超时10秒
|
||||
# pool: 连接池超时10秒
|
||||
http_client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(
|
||||
connect=10.0,
|
||||
read=180.0,
|
||||
write=10.0,
|
||||
pool=10.0
|
||||
),
|
||||
limits=limits
|
||||
)
|
||||
|
||||
client_kwargs = {
|
||||
"api_key": settings.openai_api_key,
|
||||
"http_client": http_client
|
||||
}
|
||||
|
||||
if settings.openai_base_url:
|
||||
client_kwargs["base_url"] = settings.openai_base_url
|
||||
|
||||
self.openai_client = AsyncOpenAI(**client_kwargs)
|
||||
logger.info("✅ OpenAI客户端初始化成功")
|
||||
logger.info(" - 超时设置:连接10s,读取180s")
|
||||
logger.info(" - 连接池:50个保活连接,最大100个并发")
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI客户端初始化失败: {e}")
|
||||
self.openai_client = None
|
||||
else:
|
||||
self.openai_client = None
|
||||
logger.warning("OpenAI API key未配置")
|
||||
|
||||
# 初始化Anthropic客户端
|
||||
if settings.anthropic_api_key:
|
||||
try:
|
||||
# 为Anthropic设置相同的超时和连接池配置
|
||||
limits = httpx.Limits(
|
||||
max_keepalive_connections=50,
|
||||
max_connections=100,
|
||||
keepalive_expiry=30.0
|
||||
)
|
||||
|
||||
http_client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(
|
||||
connect=10.0,
|
||||
read=180.0,
|
||||
write=10.0,
|
||||
pool=10.0
|
||||
),
|
||||
limits=limits
|
||||
)
|
||||
|
||||
client_kwargs = {
|
||||
"api_key": settings.anthropic_api_key,
|
||||
"http_client": http_client
|
||||
}
|
||||
|
||||
if settings.anthropic_base_url:
|
||||
client_kwargs["base_url"] = settings.anthropic_base_url
|
||||
|
||||
self.anthropic_client = AsyncAnthropic(**client_kwargs)
|
||||
logger.info("✅ Anthropic客户端初始化成功")
|
||||
logger.info(" - 超时设置:连接10s,读取180s")
|
||||
logger.info(" - 连接池:50个保活连接,最大100个并发")
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic客户端初始化失败: {e}")
|
||||
self.anthropic_client = None
|
||||
else:
|
||||
self.anthropic_client = None
|
||||
logger.warning("Anthropic API key未配置")
|
||||
|
||||
async def generate_text(
|
||||
self,
|
||||
prompt: str,
|
||||
provider: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
system_prompt: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
生成文本
|
||||
|
||||
Args:
|
||||
prompt: 用户提示词
|
||||
provider: AI提供商 (openai/anthropic)
|
||||
model: 模型名称
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
system_prompt: 系统提示词
|
||||
|
||||
Returns:
|
||||
生成的文本
|
||||
"""
|
||||
provider = provider or settings.default_ai_provider
|
||||
model = model or settings.default_model
|
||||
temperature = temperature or settings.default_temperature
|
||||
max_tokens = max_tokens or settings.default_max_tokens
|
||||
|
||||
if provider == "openai":
|
||||
return await self._generate_openai(
|
||||
prompt, model, temperature, max_tokens, system_prompt
|
||||
)
|
||||
elif provider == "anthropic":
|
||||
return await self._generate_anthropic(
|
||||
prompt, model, temperature, max_tokens, system_prompt
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"不支持的AI提供商: {provider}")
|
||||
|
||||
async def generate_text_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
provider: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
system_prompt: Optional[str] = None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
流式生成文本
|
||||
|
||||
Args:
|
||||
prompt: 用户提示词
|
||||
provider: AI提供商
|
||||
model: 模型名称
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
system_prompt: 系统提示词
|
||||
|
||||
Yields:
|
||||
生成的文本片段
|
||||
"""
|
||||
provider = provider or settings.default_ai_provider
|
||||
model = model or settings.default_model
|
||||
temperature = temperature or settings.default_temperature
|
||||
max_tokens = max_tokens or settings.default_max_tokens
|
||||
|
||||
if provider == "openai":
|
||||
async for chunk in self._generate_openai_stream(
|
||||
prompt, model, temperature, max_tokens, system_prompt
|
||||
):
|
||||
yield chunk
|
||||
elif provider == "anthropic":
|
||||
async for chunk in self._generate_anthropic_stream(
|
||||
prompt, model, temperature, max_tokens, system_prompt
|
||||
):
|
||||
yield chunk
|
||||
else:
|
||||
raise ValueError(f"不支持的AI提供商: {provider}")
|
||||
|
||||
async def _generate_openai(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str]
|
||||
) -> str:
|
||||
"""使用OpenAI生成文本"""
|
||||
if not self.openai_client:
|
||||
raise ValueError("OpenAI客户端未初始化,请检查API key配置")
|
||||
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
try:
|
||||
logger.info(f"🔵 开始调用OpenAI API")
|
||||
logger.info(f" - 模型: {model}")
|
||||
logger.info(f" - 温度: {temperature}")
|
||||
logger.info(f" - 最大tokens: {max_tokens}")
|
||||
logger.info(f" - Prompt长度: {len(prompt)} 字符")
|
||||
logger.info(f" - 消息数量: {len(messages)}")
|
||||
|
||||
response = await self.openai_client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
|
||||
logger.info(f"✅ OpenAI API调用成功")
|
||||
logger.info(f" - 响应ID: {response.id if hasattr(response, 'id') else 'N/A'}")
|
||||
logger.info(f" - 选项数量: {len(response.choices)}")
|
||||
|
||||
if not response.choices:
|
||||
logger.error("❌ OpenAI返回的choices为空")
|
||||
return ""
|
||||
|
||||
content = response.choices[0].message.content
|
||||
logger.info(f" - 返回内容长度: {len(content) if content else 0} 字符")
|
||||
|
||||
if content:
|
||||
logger.info(f" - 返回内容预览(前200字符): {content[:200]}")
|
||||
return content
|
||||
else:
|
||||
logger.error("❌ OpenAI返回了空内容")
|
||||
logger.error(f" - 完整响应: {response}")
|
||||
raise ValueError("AI返回了空内容,请检查API配置或稍后重试")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ OpenAI API调用失败")
|
||||
logger.error(f" - 错误类型: {type(e).__name__}")
|
||||
logger.error(f" - 错误信息: {str(e)}")
|
||||
logger.error(f" - 模型: {model}")
|
||||
raise
|
||||
|
||||
async def _generate_openai_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str]
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""使用OpenAI流式生成文本"""
|
||||
if not self.openai_client:
|
||||
raise ValueError("OpenAI客户端未初始化,请检查API key配置")
|
||||
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
try:
|
||||
logger.info(f"🔵 开始调用OpenAI流式API")
|
||||
logger.info(f" - 模型: {model}")
|
||||
logger.info(f" - Prompt长度: {len(prompt)} 字符")
|
||||
logger.info(f" - 最大tokens: {max_tokens}")
|
||||
|
||||
stream = await self.openai_client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=True
|
||||
)
|
||||
|
||||
logger.info(f"✅ OpenAI流式API连接成功,开始接收数据...")
|
||||
|
||||
chunk_count = 0
|
||||
async for chunk in stream:
|
||||
if chunk.choices and len(chunk.choices) > 0:
|
||||
if chunk.choices[0].delta.content:
|
||||
chunk_count += 1
|
||||
yield chunk.choices[0].delta.content
|
||||
|
||||
logger.info(f"✅ OpenAI流式生成完成,共接收 {chunk_count} 个chunk")
|
||||
|
||||
except httpx.TimeoutException as e:
|
||||
logger.error(f"❌ OpenAI流式API超时")
|
||||
logger.error(f" - 错误: {str(e)}")
|
||||
logger.error(f" - 提示: 请检查网络连接或考虑缩短prompt长度")
|
||||
raise TimeoutError(f"AI服务超时(180秒),请稍后重试或减少上下文长度") from e
|
||||
except Exception as e:
|
||||
logger.error(f"❌ OpenAI流式API调用失败: {str(e)}")
|
||||
logger.error(f" - 错误类型: {type(e).__name__}")
|
||||
raise
|
||||
|
||||
async def _generate_anthropic(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str]
|
||||
) -> str:
|
||||
"""使用Anthropic生成文本"""
|
||||
if not self.anthropic_client:
|
||||
raise ValueError("Anthropic客户端未初始化,请检查API key配置")
|
||||
|
||||
try:
|
||||
response = await self.anthropic_client.messages.create(
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
system=system_prompt or "",
|
||||
messages=[{"role": "user", "content": prompt}]
|
||||
)
|
||||
return response.content[0].text
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic API调用失败: {str(e)}")
|
||||
raise
|
||||
|
||||
async def _generate_anthropic_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str]
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""使用Anthropic流式生成文本"""
|
||||
if not self.anthropic_client:
|
||||
raise ValueError("Anthropic客户端未初始化,请检查API key配置")
|
||||
|
||||
try:
|
||||
logger.info(f"🔵 开始调用Anthropic流式API")
|
||||
logger.info(f" - 模型: {model}")
|
||||
logger.info(f" - Prompt长度: {len(prompt)} 字符")
|
||||
logger.info(f" - 最大tokens: {max_tokens}")
|
||||
|
||||
async with self.anthropic_client.messages.stream(
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
system=system_prompt or "",
|
||||
messages=[{"role": "user", "content": prompt}]
|
||||
) as stream:
|
||||
logger.info(f"✅ Anthropic流式API连接成功,开始接收数据...")
|
||||
|
||||
chunk_count = 0
|
||||
async for text in stream.text_stream:
|
||||
chunk_count += 1
|
||||
yield text
|
||||
|
||||
logger.info(f"✅ Anthropic流式生成完成,共接收 {chunk_count} 个chunk")
|
||||
|
||||
except httpx.TimeoutException as e:
|
||||
logger.error(f"❌ Anthropic流式API超时")
|
||||
logger.error(f" - 错误: {str(e)}")
|
||||
raise TimeoutError(f"AI服务超时(180秒),请稍后重试或减少上下文长度") from e
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Anthropic流式API调用失败: {str(e)}")
|
||||
logger.error(f" - 错误类型: {type(e).__name__}")
|
||||
raise
|
||||
|
||||
|
||||
# 创建全局AI服务实例
|
||||
ai_service = AIService()
|
||||
@@ -0,0 +1,149 @@
|
||||
"""
|
||||
LinuxDO OAuth2 服务
|
||||
"""
|
||||
import httpx
|
||||
import secrets
|
||||
from typing import Optional, Dict, Any
|
||||
from app.config import settings
|
||||
|
||||
|
||||
class LinuxDOOAuthService:
|
||||
"""LinuxDO OAuth2 服务类"""
|
||||
|
||||
# LinuxDO OAuth2 端点
|
||||
AUTHORIZE_URL = "https://connect.linux.do/oauth2/authorize"
|
||||
TOKEN_URL = "https://connect.linux.do/oauth2/token"
|
||||
USERINFO_URL = "https://connect.linux.do/api/user" # 修复:使用正确的用户信息端点
|
||||
|
||||
def __init__(self):
|
||||
self.client_id = settings.LINUXDO_CLIENT_ID
|
||||
self.client_secret = settings.LINUXDO_CLIENT_SECRET
|
||||
self.redirect_uri = settings.LINUXDO_REDIRECT_URI
|
||||
|
||||
# 验证redirect_uri配置
|
||||
if not self.redirect_uri:
|
||||
raise ValueError(
|
||||
"LINUXDO_REDIRECT_URI 未配置!\n"
|
||||
"请在 .env 文件中设置正确的回调地址:\n"
|
||||
"本地开发: LINUXDO_REDIRECT_URI=http://localhost:8000/api/auth/callback\n"
|
||||
"Docker部署: LINUXDO_REDIRECT_URI=https://your-domain.com/api/auth/callback"
|
||||
)
|
||||
|
||||
# 警告:检查是否使用了localhost(在非开发环境)
|
||||
if not settings.debug and "localhost" in self.redirect_uri.lower():
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(
|
||||
f"⚠️ 生产环境检测到使用 localhost 作为回调地址: {self.redirect_uri}\n"
|
||||
"这可能导致OAuth回调失败!请使用实际的域名或服务器IP。"
|
||||
)
|
||||
|
||||
def generate_state(self) -> str:
|
||||
"""生成随机 state 参数"""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
def get_authorization_url(self, state: str) -> str:
|
||||
"""
|
||||
获取授权 URL
|
||||
|
||||
Args:
|
||||
state: 随机 state 参数
|
||||
|
||||
Returns:
|
||||
授权 URL
|
||||
"""
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"response_type": "code",
|
||||
"scope": "read",
|
||||
"state": state
|
||||
}
|
||||
|
||||
query_string = "&".join([f"{k}={v}" for k, v in params.items()])
|
||||
return f"{self.AUTHORIZE_URL}?{query_string}"
|
||||
|
||||
async def get_access_token(self, code: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
使用授权码获取访问令牌
|
||||
|
||||
Args:
|
||||
code: 授权码
|
||||
|
||||
Returns:
|
||||
包含 access_token 的字典,失败返回 None
|
||||
"""
|
||||
data = {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"code": code,
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": self.redirect_uri
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self.TOKEN_URL,
|
||||
data=data,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
print(f"获取访问令牌失败: {response.status_code} {response.text}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取访问令牌异常: {e}")
|
||||
return None
|
||||
|
||||
async def get_user_info(self, access_token: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
使用访问令牌获取用户信息
|
||||
|
||||
Args:
|
||||
access_token: 访问令牌
|
||||
|
||||
Returns:
|
||||
用户信息字典,失败返回 None
|
||||
"""
|
||||
try:
|
||||
# 添加真实浏览器请求头,避免被 Cloudflare 拦截
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
|
||||
"Accept": "application/json",
|
||||
"Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8",
|
||||
}
|
||||
|
||||
# 不自动处理编码,让 httpx 自动解压
|
||||
async with httpx.AsyncClient(follow_redirects=True, timeout=30.0) as client:
|
||||
response = await client.get(
|
||||
self.USERINFO_URL,
|
||||
headers=headers
|
||||
)
|
||||
|
||||
print(f"获取用户信息响应状态: {response.status_code}")
|
||||
print(f"响应头: {response.headers}")
|
||||
|
||||
if response.status_code == 200:
|
||||
try:
|
||||
user_data = response.json()
|
||||
print(f"用户信息: {user_data}")
|
||||
return user_data
|
||||
except Exception as json_error:
|
||||
print(f"解析 JSON 失败: {json_error}")
|
||||
print(f"响应内容前100字符: {response.text[:100]}")
|
||||
return None
|
||||
else:
|
||||
print(f"获取用户信息失败: {response.status_code}")
|
||||
print(f"响应内容: {response.text[:200]}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取用户信息异常: {type(e).__name__}: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return None
|
||||
@@ -0,0 +1,730 @@
|
||||
"""提示词管理服务"""
|
||||
from typing import Dict, Any
|
||||
import json
|
||||
|
||||
|
||||
class PromptService:
|
||||
"""提示词模板管理"""
|
||||
|
||||
# 世界构建提示词
|
||||
WORLD_BUILDING = """你是一位资深的世界观设计师。请根据以下信息构建一个完整的小说世界观:
|
||||
|
||||
书名:{title}
|
||||
主题:{theme}
|
||||
类型:{genre}
|
||||
|
||||
请生成包含以下内容的世界构建框架:
|
||||
|
||||
1. **时间背景**:具体的时代设定、时间流逝特点、重要历史事件
|
||||
2. **地理位置**:主要地点描述、地理环境特征、空间布局
|
||||
3. **氛围基调**:整体氛围感觉、情感色彩、视觉风格
|
||||
4. **世界规则**:基本运行法则、特殊设定、社会规则和禁忌、权力结构
|
||||
|
||||
要求:
|
||||
- 与主题高度契合
|
||||
- 设定要合理自洽
|
||||
- 为故事发展提供支撑
|
||||
- 具有独特性和吸引力
|
||||
|
||||
**重要:你必须只返回纯JSON格式,不要包含任何markdown标记、代码块标记或其他说明文字。**
|
||||
|
||||
请严格按照以下JSON格式返回(每个字段为200-300字的文本描述):
|
||||
{{
|
||||
"time_period": "时间背景的详细描述,包括时代设定、时间特点、历史事件",
|
||||
"location": "地理位置的详细描述,包括主要地点、环境特征、空间布局",
|
||||
"atmosphere": "氛围基调的详细描述,包括整体氛围、情感色彩、视觉风格",
|
||||
"rules": "世界规则的详细描述,包括运行法则、特殊设定、社会规则、权力结构"
|
||||
}}
|
||||
|
||||
再次强调:只返回纯JSON对象,不要有```json```这样的标记,不要有任何额外的文字说明。"""
|
||||
|
||||
# 批量角色生成提示词
|
||||
CHARACTERS_BATCH_GENERATION = """你是一位专业的角色设定师。请根据以下世界观和要求,生成{count}个立体丰满的角色和组织:
|
||||
|
||||
世界观信息:
|
||||
- 时间背景:{time_period}
|
||||
- 地理位置:{location}
|
||||
- 氛围基调:{atmosphere}
|
||||
- 世界规则:{rules}
|
||||
|
||||
主题:{theme}
|
||||
类型:{genre}
|
||||
特殊要求:{requirements}
|
||||
|
||||
【数量要求 - 必须严格遵守】
|
||||
请精确生成{count}个实体,不多不少。数组中必须包含且仅包含{count}个对象。
|
||||
|
||||
实体类型分配:
|
||||
- 至少1个主角(protagonist)
|
||||
- 多个配角(supporting)
|
||||
- 可以包含反派(antagonist)
|
||||
- 可以包含1-2个重要组织
|
||||
|
||||
要求:
|
||||
- 角色要符合世界观设定
|
||||
- 性格和背景要有深度
|
||||
- 角色之间要有关系网络
|
||||
- 组织要有存在的合理性
|
||||
- 所有实体要为故事服务
|
||||
|
||||
**重要:你必须只返回纯JSON数组格式,不要包含任何markdown标记、代码块标记或其他说明文字。**
|
||||
|
||||
请严格按照以下JSON数组格式返回(每个角色为数组中的一个对象):
|
||||
[
|
||||
{{
|
||||
"name": "角色姓名",
|
||||
"age": 25,
|
||||
"gender": "男/女/其他",
|
||||
"is_organization": false,
|
||||
"role_type": "protagonist/supporting/antagonist",
|
||||
"personality": "性格特点的详细描述(100-200字),包括核心性格、优缺点、特殊习惯",
|
||||
"background": "背景故事的详细描述(100-200字),包括家庭背景、成长经历、重要转折",
|
||||
"appearance": "外貌描述(50-100字),包括身高、体型、面容、着装风格",
|
||||
"traits": ["特长1", "特长2", "特长3"],
|
||||
"relationships_array": [
|
||||
{{
|
||||
"target_character_name": "已生成的角色名称",
|
||||
"relationship_type": "关系类型(师父/朋友/敌人/父亲/母亲等)",
|
||||
"intimacy_level": 75,
|
||||
"description": "关系描述"
|
||||
}}
|
||||
],
|
||||
"organization_memberships": [
|
||||
{{
|
||||
"organization_name": "已生成的组织名称",
|
||||
"position": "职位",
|
||||
"rank": 5,
|
||||
"loyalty": 80
|
||||
}}
|
||||
]
|
||||
}},
|
||||
{{
|
||||
"name": "组织名称",
|
||||
"is_organization": true,
|
||||
"role_type": "supporting",
|
||||
"personality": "组织特性描述(100-200字),包括运作方式、核心理念、行事风格",
|
||||
"background": "组织背景(100-200字),包括建立历史、发展历程、重要事件",
|
||||
"appearance": "组织外在表现(50-100字),如总部位置、标志性建筑等",
|
||||
"organization_type": "组织类型",
|
||||
"organization_purpose": "组织目的",
|
||||
"organization_members": ["成员1", "成员2"],
|
||||
"traits": []
|
||||
}}
|
||||
]
|
||||
|
||||
**关系类型参考(从中选择或自定义):**
|
||||
- 家族:父亲、母亲、兄弟、姐妹、子女、配偶、恋人
|
||||
- 社交:师父、徒弟、朋友、同学、同事、邻居、知己
|
||||
- 职业:上司、下属、合作伙伴
|
||||
- 敌对:敌人、仇人、竞争对手、宿敌
|
||||
|
||||
**重要说明:**
|
||||
1. **数量控制**:数组中必须精确包含{count}个对象,不能多也不能少
|
||||
2. **关系约束**:relationships_array只能引用本批次中已经出现的角色名称
|
||||
3. **组织约束**:organization_memberships只能引用本批次中is_organization=true的实体名称
|
||||
4. **禁止幻觉**:不要引用任何不存在的角色或组织,如果没有可引用的就留空数组[]
|
||||
5. intimacy_level和loyalty都是0-100的整数
|
||||
6. 角色之间要形成合理的关系网络
|
||||
|
||||
**示例说明**:
|
||||
- 如果生成了角色A、组织B、角色C,则角色A的organization_memberships只能是[组织B],不能是其他组织
|
||||
- 如果角色A在数组第一位,它的relationships_array必须为空[],因为还没有其他角色
|
||||
- 如果角色C在数组第三位,它的relationships_array可以引用角色A,但不能引用不存在的角色D
|
||||
|
||||
再次强调:
|
||||
1. 只返回纯JSON数组,不要有```json```这样的标记
|
||||
2. 数组中必须精确包含{count}个对象
|
||||
3. 不要引用任何本批次中不存在的角色或组织名称"""
|
||||
|
||||
# 完整大纲生成提示词
|
||||
COMPLETE_OUTLINE_GENERATION = """你是一位经验丰富的小说作家和编剧。请根据以下信息生成完整的{chapter_count}章小说大纲:
|
||||
|
||||
基本信息:
|
||||
- 书名:{title}
|
||||
- 主题:{theme}
|
||||
- 类型:{genre}
|
||||
- 章节数:{chapter_count}
|
||||
- 叙事视角:{narrative_perspective}
|
||||
- 目标字数:{target_words}
|
||||
|
||||
世界观:
|
||||
- 时间背景:{time_period}
|
||||
- 地理位置:{location}
|
||||
- 氛围基调:{atmosphere}
|
||||
- 世界规则:{rules}
|
||||
|
||||
角色信息:
|
||||
{characters_info}
|
||||
|
||||
其他要求:{requirements}
|
||||
|
||||
整体要求:
|
||||
- 结构完整:起承转合清晰
|
||||
- 情节连贯:章节之间紧密衔接
|
||||
- 冲突递进:矛盾逐步升级
|
||||
- 人物成长:角色有明确的变化弧线
|
||||
- 节奏把控:有张有弛
|
||||
- 视角统一:采用{narrative_perspective}视角叙事
|
||||
|
||||
**重要:你必须只返回纯JSON数组格式,不要包含任何markdown标记、代码块标记或其他说明文字。**
|
||||
|
||||
请严格按照以下JSON数组格式返回(共{chapter_count}个章节对象):
|
||||
[
|
||||
{{
|
||||
"chapter_number": 1,
|
||||
"title": "第一章标题",
|
||||
"summary": "章节概要的详细描述(100-200字),包含主要情节、冲突、转折等",
|
||||
"scenes": ["场景1描述", "场景2描述", "场景3描述"],
|
||||
"characters": ["角色1", "角色2"],
|
||||
"key_points": ["情节要点1", "情节要点2"],
|
||||
"emotion": "本章情感基调",
|
||||
"goal": "本章叙事目标"
|
||||
}},
|
||||
{{
|
||||
"chapter_number": 2,
|
||||
"title": "第二章标题",
|
||||
"summary": "章节概要...",
|
||||
"scenes": ["场景1", "场景2"],
|
||||
"characters": ["角色1", "角色2"],
|
||||
"key_points": ["要点1", "要点2"],
|
||||
"emotion": "情感基调",
|
||||
"goal": "叙事目标"
|
||||
}}
|
||||
]
|
||||
|
||||
再次强调:只返回纯JSON数组,不要有```json```这样的标记,不要有任何额外的文字说明。数组中要包含{chapter_count}个章节对象。"""
|
||||
|
||||
# 大纲续写提示词
|
||||
OUTLINE_CONTINUE_GENERATION = """你是一位经验丰富的小说作家和编剧。请基于以下信息续写小说大纲:
|
||||
|
||||
【项目信息】
|
||||
- 书名:{title}
|
||||
- 主题:{theme}
|
||||
- 类型:{genre}
|
||||
- 叙事视角:{narrative_perspective}
|
||||
- 续写章节数:{chapter_count}章
|
||||
|
||||
【世界观】
|
||||
- 时间背景:{time_period}
|
||||
- 地理位置:{location}
|
||||
- 氛围基调:{atmosphere}
|
||||
- 世界规则:{rules}
|
||||
|
||||
【角色信息】
|
||||
{characters_info}
|
||||
|
||||
【已有章节概览】(共{current_chapter_count}章)
|
||||
{all_chapters_brief}
|
||||
|
||||
【最近剧情】
|
||||
{recent_plot}
|
||||
|
||||
【续写指导】
|
||||
- 当前情节阶段:{plot_stage_instruction}
|
||||
- 起始章节编号:第{start_chapter}章
|
||||
- 故事发展方向:{story_direction}
|
||||
- 其他要求:{requirements}
|
||||
|
||||
请生成第{start_chapter}章到第{end_chapter}章的大纲。
|
||||
要求:
|
||||
- 与前文自然衔接,保持故事连贯性
|
||||
- 遵循情节阶段的发展要求
|
||||
- 保持与已有章节相同的风格和详细程度
|
||||
- 推进角色成长和情节发展
|
||||
|
||||
**重要:你必须只返回纯JSON数组格式,不要包含任何markdown标记、代码块标记或其他说明文字。**
|
||||
|
||||
请严格按照以下JSON数组格式返回(共{chapter_count}个章节对象):
|
||||
[
|
||||
{{
|
||||
"chapter_number": {start_chapter},
|
||||
"title": "章节标题",
|
||||
"summary": "章节概要的详细描述(100-200字),包含主要情节、角色互动、关键事件、冲突与转折",
|
||||
"scenes": ["场景1描述", "场景2描述", "场景3描述"],
|
||||
"characters": ["涉及角色1", "涉及角色2"],
|
||||
"key_points": ["情节要点1", "情节要点2"],
|
||||
"emotion": "本章情感基调",
|
||||
"goal": "本章叙事目标"
|
||||
}},
|
||||
{{
|
||||
"chapter_number": {start_chapter} + 1,
|
||||
"title": "章节标题",
|
||||
"summary": "章节概要...",
|
||||
"scenes": ["场景1", "场景2"],
|
||||
"characters": ["角色1", "角色2"],
|
||||
"key_points": ["要点1", "要点2"],
|
||||
"emotion": "情感基调",
|
||||
"goal": "叙事目标"
|
||||
}}
|
||||
]
|
||||
|
||||
再次强调:
|
||||
1. 只返回纯JSON数组,不要有```json```这样的标记
|
||||
2. 数组中要包含{chapter_count}个章节对象
|
||||
3. 每个summary必须是100-200字的详细描述
|
||||
4. 确保字段结构与已有章节完全一致"""
|
||||
|
||||
# AI去味提示词(核心特色功能)
|
||||
AI_DENOISING = """你是一位追求自然写作风格的编辑。你的任务是将AI生成的文本改写得更像人类作家的手笔。
|
||||
|
||||
原文:
|
||||
{original_text}
|
||||
|
||||
修改要求:
|
||||
1. 去除AI痕迹:
|
||||
- 删除过于工整的排比句
|
||||
- 减少重复的修辞手法
|
||||
- 去掉刻意的对称结构
|
||||
- 避免机械式的总结陈词
|
||||
|
||||
2. 增加人性化:
|
||||
- 使用更口语化的表达
|
||||
- 添加不完美的细节
|
||||
- 保留适度的随意性
|
||||
- 增加真实的情感波动
|
||||
|
||||
3. 优化叙事:
|
||||
- 让节奏更自然不做作
|
||||
- 用简单词汇替换华丽辞藻
|
||||
- 保持叙述的松弛感
|
||||
- 让对话更生活化
|
||||
|
||||
4. 保持原意:
|
||||
- 不改变核心情节
|
||||
- 保留关键信息点
|
||||
- 维持角色性格
|
||||
- 确保逻辑连贯
|
||||
|
||||
修改风格:
|
||||
- 像是一个喜欢讲故事的普通人写的
|
||||
- 有点粗糙但很真诚
|
||||
- 自然流畅不刻意
|
||||
- 让人读起来很舒服
|
||||
|
||||
请直接输出修改后的文本,无需解释。"""
|
||||
|
||||
# 章节完整创作提示词
|
||||
CHAPTER_GENERATION = """你是一位专业的小说作家。请根据以下信息创作本章内容:
|
||||
|
||||
项目信息:
|
||||
- 书名:{title}
|
||||
- 主题:{theme}
|
||||
- 类型:{genre}
|
||||
- 叙事视角:{narrative_perspective}
|
||||
|
||||
世界观:
|
||||
- 时间背景:{time_period}
|
||||
- 地理位置:{location}
|
||||
- 氛围基调:{atmosphere}
|
||||
- 世界规则:{rules}
|
||||
|
||||
角色信息:
|
||||
{characters_info}
|
||||
|
||||
全书大纲:
|
||||
{outlines_context}
|
||||
|
||||
本章信息:
|
||||
- 章节序号:第{chapter_number}章
|
||||
- 章节标题:{chapter_title}
|
||||
- 章节大纲:{chapter_outline}
|
||||
|
||||
创作要求:
|
||||
1. 严格按照大纲内容展开情节
|
||||
2. 保持与前后章节的连贯性
|
||||
3. 符合角色性格设定
|
||||
4. 体现世界观特色
|
||||
5. 使用{narrative_perspective}视角
|
||||
6. 字数不得低于3000字
|
||||
7. 语言自然流畅,避免AI痕迹
|
||||
|
||||
**写作风格要求(重要):**
|
||||
- 让故事自然流淌,写到哪算哪
|
||||
- 结尾处直接结束情节,不要加总结性段落
|
||||
- 不要在章节末尾写"这一天/这一夜就这样过去了"之类的总结句
|
||||
- 不要用"他/她陷入了沉思"作为结尾
|
||||
- 避免刻意的情感升华或哲理感悟收尾
|
||||
- 章节结尾可以戛然而止,可以是对话,可以是动作,可以是悬念
|
||||
- 就像在讲一个故事,讲完了就停,不需要画龙点睛
|
||||
|
||||
请直接输出章节正文内容,不要包含章节标题和其他说明文字。"""
|
||||
|
||||
# 章节完整创作提示词(带前置章节上下文)
|
||||
CHAPTER_GENERATION_WITH_CONTEXT = """你是一位专业的小说作家。请根据以下信息创作本章内容:
|
||||
|
||||
项目信息:
|
||||
- 书名:{title}
|
||||
- 主题:{theme}
|
||||
- 类型:{genre}
|
||||
- 叙事视角:{narrative_perspective}
|
||||
|
||||
世界观:
|
||||
- 时间背景:{time_period}
|
||||
- 地理位置:{location}
|
||||
- 氛围基调:{atmosphere}
|
||||
- 世界规则:{rules}
|
||||
|
||||
角色信息:
|
||||
{characters_info}
|
||||
|
||||
全书大纲:
|
||||
{outlines_context}
|
||||
|
||||
【已完成的前置章节内容】
|
||||
{previous_content}
|
||||
|
||||
本章信息:
|
||||
- 章节序号:第{chapter_number}章
|
||||
- 章节标题:{chapter_title}
|
||||
- 章节大纲:{chapter_outline}
|
||||
|
||||
创作要求:
|
||||
1. **剧情连贯性(最重要)**:
|
||||
- 必须承接前面章节的剧情发展
|
||||
- 注意角色状态、情节进展、时间线的连续性
|
||||
- 不能出现与前文矛盾的内容
|
||||
- 自然过渡,避免突兀的跳跃
|
||||
|
||||
2. **情节推进**:
|
||||
- 严格按照本章大纲展开情节
|
||||
- 推动故事向前发展
|
||||
- 保持与全书大纲的一致性
|
||||
|
||||
3. **角色一致性**:
|
||||
- 符合角色性格设定
|
||||
- 延续角色在前文中的成长和变化
|
||||
- 保持角色关系的连贯性
|
||||
|
||||
4. **写作风格**:
|
||||
- 使用{narrative_perspective}视角
|
||||
- 字数不得低于3000字
|
||||
- 语言自然流畅,避免AI痕迹
|
||||
- 体现世界观特色
|
||||
|
||||
5. **承上启下**:
|
||||
- 开头自然衔接上一章结尾
|
||||
- 结尾为下一章做好铺垫
|
||||
|
||||
**写作风格要求(重要):**
|
||||
- 让故事自然流淌,写到哪算哪
|
||||
- 结尾处直接结束情节,不要加总结性段落
|
||||
- 不要在章节末尾写"这一天/这一夜就这样过去了"之类的总结句
|
||||
- 不要用"他/她陷入了沉思"作为结尾
|
||||
- 避免刻意的情感升华或哲理感悟收尾
|
||||
- 章节结尾可以戛然而止,可以是对话,可以是动作,可以是悬念
|
||||
- 就像在讲一个故事,讲完了就停,不需要画龙点睛
|
||||
|
||||
请直接输出章节正文内容,不要包含章节标题和其他说明文字。"""
|
||||
|
||||
# 大纲生成提示词
|
||||
OUTLINE_GENERATION = """你是一位经验丰富的小说作家和编剧。请根据以下信息生成小说大纲:
|
||||
|
||||
类型:{genre}
|
||||
主题:{theme}
|
||||
目标字数:{target_words}
|
||||
其他要求:{requirements}
|
||||
|
||||
请生成一个完整的章节大纲框架,包含:
|
||||
1. 合理的章节数量(根据字数)
|
||||
2. 每章的标题和内容概要
|
||||
3. 清晰的故事结构(起承转合)
|
||||
4. 情节的递进和冲突升级
|
||||
5. 角色的成长弧线
|
||||
|
||||
**重要:你必须只返回纯JSON格式,不要包含任何markdown标记、代码块标记或其他说明文字。**
|
||||
|
||||
请严格按照以下JSON格式返回:
|
||||
{{
|
||||
"chapters": [
|
||||
{{
|
||||
"order": 1,
|
||||
"title": "章节标题",
|
||||
"content": "章节内容概要(150-200字)"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
再次强调:只返回纯JSON对象,不要有```json```这样的标记,不要有任何额外的文字说明。"""
|
||||
|
||||
# 单个角色生成提示词
|
||||
SINGLE_CHARACTER_GENERATION = """你是一位专业的角色设定师。请根据以下信息创建一个立体饱满的小说角色。
|
||||
|
||||
{project_context}
|
||||
|
||||
{user_input}
|
||||
|
||||
请生成一个完整的角色卡片,包含以下所有信息:
|
||||
|
||||
1. **基本信息**:
|
||||
- 姓名:如果用户未提供,请生成一个符合世界观的名字
|
||||
- 年龄:具体数字或年龄段
|
||||
- 性别:男/女/其他
|
||||
|
||||
2. **外貌特征**(100-150字):
|
||||
- 身高体型、面容特征、着装风格
|
||||
- 要符合角色定位和世界观设定
|
||||
|
||||
3. **性格特点**(150-200字):
|
||||
- 核心性格特质(至少3个)
|
||||
- 优点和缺点
|
||||
- 特殊习惯或癖好
|
||||
- 性格要有复杂性和矛盾性
|
||||
|
||||
4. **背景故事**(200-300字):
|
||||
- 家庭背景
|
||||
- 成长经历
|
||||
- 重要转折事件
|
||||
- 如何与项目主题关联
|
||||
- 融入用户提供的背景设定
|
||||
|
||||
5. **人际关系**:
|
||||
- 与现有角色的关系(如果有)
|
||||
- 重要的人际纽带
|
||||
- 社会地位和人脉
|
||||
|
||||
6. **特殊能力/特长**:
|
||||
- 擅长的领域
|
||||
- 特殊技能或知识
|
||||
- 符合世界观设定
|
||||
|
||||
**你必须只返回纯JSON格式,不要包含任何markdown标记、代码块标记或其他说明文字。**
|
||||
|
||||
请严格按照以下JSON格式返回:
|
||||
{{
|
||||
"name": "角色姓名",
|
||||
"age": "年龄",
|
||||
"gender": "性别",
|
||||
"appearance": "外貌描述(100-150字)",
|
||||
"personality": "性格特点(150-200字)",
|
||||
"background": "背景故事(200-300字)",
|
||||
"traits": ["特长1", "特长2", "特长3"],
|
||||
|
||||
"relationships_text": "人际关系的文字描述(用于显示)",
|
||||
|
||||
"relationships": [
|
||||
{{
|
||||
"target_character_name": "已存在的角色名称",
|
||||
"relationship_type": "关系类型(如:师父、朋友、敌人、父亲、母亲等)",
|
||||
"intimacy_level": 75,
|
||||
"description": "这段关系的详细描述",
|
||||
"started_at": "关系开始的故事时间点(可选)"
|
||||
}}
|
||||
],
|
||||
|
||||
"organization_memberships": [
|
||||
{{
|
||||
"organization_name": "已存在的组织名称",
|
||||
"position": "职位名称",
|
||||
"rank": 8,
|
||||
"loyalty": 80,
|
||||
"joined_at": "加入时间(可选)",
|
||||
"status": "active"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
**关系类型参考(请从中选择或自定义):**
|
||||
- 家族关系:父亲、母亲、兄弟、姐妹、子女、配偶、恋人
|
||||
- 社交关系:师父、徒弟、朋友、同学、同事、邻居、知己
|
||||
- 职业关系:上司、下属、合作伙伴
|
||||
- 敌对关系:敌人、仇人、竞争对手、宿敌
|
||||
|
||||
**重要说明:**
|
||||
1. relationships数组:只包含与上面列出的已存在角色的关系,通过target_character_name匹配
|
||||
2. organization_memberships数组:只包含与上面列出的已存在组织的关系
|
||||
3. intimacy_level和loyalty都是0-100的整数
|
||||
4. 如果没有关系或组织,对应数组为空[]
|
||||
5. relationships_text是自然语言描述,用于展示给用户看
|
||||
|
||||
**角色设定要求:**
|
||||
- 角色要符合项目的世界观和主题
|
||||
- 如果是主角,要有明确的成长空间和目标动机
|
||||
- 如果是反派,要有合理的动机,不能脸谱化
|
||||
- 配角要有独特性,不能是工具人
|
||||
- 所有设定要为故事服务
|
||||
|
||||
再次强调:只返回纯JSON对象,不要有```json```这样的标记,不要有任何额外的文字说明。"""
|
||||
|
||||
@staticmethod
|
||||
def format_prompt(template: str, **kwargs) -> str:
|
||||
"""
|
||||
格式化提示词模板
|
||||
|
||||
Args:
|
||||
template: 提示词模板
|
||||
**kwargs: 模板参数
|
||||
|
||||
Returns:
|
||||
格式化后的提示词
|
||||
"""
|
||||
try:
|
||||
return template.format(**kwargs)
|
||||
except KeyError as e:
|
||||
raise ValueError(f"缺少必需的参数: {e}")
|
||||
|
||||
@classmethod
|
||||
def get_denoising_prompt(cls, original_text: str) -> str:
|
||||
"""获取AI去味提示词"""
|
||||
return cls.format_prompt(
|
||||
cls.AI_DENOISING,
|
||||
original_text=original_text
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_world_building_prompt(cls, title: str, theme: str, genre: str = "") -> str:
|
||||
"""获取世界构建提示词"""
|
||||
return cls.format_prompt(
|
||||
cls.WORLD_BUILDING,
|
||||
title=title,
|
||||
theme=theme,
|
||||
genre=genre or "通用类型"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_characters_batch_prompt(cls, count: int, time_period: str, location: str,
|
||||
atmosphere: str, rules: str, theme: str,
|
||||
genre: str = "", requirements: str = "") -> str:
|
||||
"""获取批量角色生成提示词"""
|
||||
return cls.format_prompt(
|
||||
cls.CHARACTERS_BATCH_GENERATION,
|
||||
count=count,
|
||||
time_period=time_period,
|
||||
location=location,
|
||||
atmosphere=atmosphere,
|
||||
rules=rules,
|
||||
theme=theme,
|
||||
genre=genre or "通用类型",
|
||||
requirements=requirements or "无特殊要求"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_complete_outline_prompt(cls, title: str, theme: str, genre: str,
|
||||
chapter_count: int, narrative_perspective: str,
|
||||
target_words: int, time_period: str, location: str,
|
||||
atmosphere: str, rules: str, characters_info: str,
|
||||
requirements: str = "") -> str:
|
||||
"""获取完整大纲生成提示词"""
|
||||
return cls.format_prompt(
|
||||
cls.COMPLETE_OUTLINE_GENERATION,
|
||||
title=title,
|
||||
theme=theme,
|
||||
genre=genre,
|
||||
chapter_count=chapter_count,
|
||||
narrative_perspective=narrative_perspective,
|
||||
target_words=target_words,
|
||||
time_period=time_period,
|
||||
location=location,
|
||||
atmosphere=atmosphere,
|
||||
rules=rules,
|
||||
characters_info=characters_info,
|
||||
requirements=requirements or "无特殊要求"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_chapter_generation_prompt(cls, title: str, theme: str, genre: str,
|
||||
narrative_perspective: str, time_period: str,
|
||||
location: str, atmosphere: str, rules: str,
|
||||
characters_info: str, outlines_context: str,
|
||||
chapter_number: int, chapter_title: str,
|
||||
chapter_outline: str) -> str:
|
||||
"""获取章节完整创作提示词"""
|
||||
return cls.format_prompt(
|
||||
cls.CHAPTER_GENERATION,
|
||||
title=title,
|
||||
theme=theme,
|
||||
genre=genre,
|
||||
narrative_perspective=narrative_perspective,
|
||||
time_period=time_period,
|
||||
location=location,
|
||||
atmosphere=atmosphere,
|
||||
rules=rules,
|
||||
characters_info=characters_info,
|
||||
outlines_context=outlines_context,
|
||||
chapter_number=chapter_number,
|
||||
chapter_title=chapter_title,
|
||||
chapter_outline=chapter_outline
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_chapter_generation_with_context_prompt(cls, title: str, theme: str, genre: str,
|
||||
narrative_perspective: str, time_period: str,
|
||||
location: str, atmosphere: str, rules: str,
|
||||
characters_info: str, outlines_context: str,
|
||||
previous_content: str, chapter_number: int,
|
||||
chapter_title: str, chapter_outline: str) -> str:
|
||||
"""获取章节完整创作提示词(带前置章节上下文)"""
|
||||
return cls.format_prompt(
|
||||
cls.CHAPTER_GENERATION_WITH_CONTEXT,
|
||||
title=title,
|
||||
theme=theme,
|
||||
genre=genre,
|
||||
narrative_perspective=narrative_perspective,
|
||||
time_period=time_period,
|
||||
location=location,
|
||||
atmosphere=atmosphere,
|
||||
rules=rules,
|
||||
characters_info=characters_info,
|
||||
outlines_context=outlines_context,
|
||||
previous_content=previous_content,
|
||||
chapter_number=chapter_number,
|
||||
chapter_title=chapter_title,
|
||||
chapter_outline=chapter_outline
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_outline_prompt(cls, genre: str, theme: str, target_words: int,
|
||||
requirements: str = "") -> str:
|
||||
"""获取大纲生成提示词"""
|
||||
return cls.format_prompt(
|
||||
cls.OUTLINE_GENERATION,
|
||||
genre=genre,
|
||||
theme=theme,
|
||||
target_words=target_words,
|
||||
requirements=requirements or "无特殊要求"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_outline_continue_prompt(cls, title: str, theme: str, genre: str,
|
||||
narrative_perspective: str, chapter_count: int,
|
||||
time_period: str, location: str, atmosphere: str,
|
||||
rules: str, characters_info: str,
|
||||
current_chapter_count: int, all_chapters_brief: str,
|
||||
recent_plot: str, plot_stage_instruction: str,
|
||||
start_chapter: int, story_direction: str,
|
||||
requirements: str = "") -> str:
|
||||
"""获取大纲续写提示词"""
|
||||
end_chapter = start_chapter + chapter_count - 1
|
||||
return cls.format_prompt(
|
||||
cls.OUTLINE_CONTINUE_GENERATION,
|
||||
title=title,
|
||||
theme=theme,
|
||||
genre=genre,
|
||||
narrative_perspective=narrative_perspective,
|
||||
chapter_count=chapter_count,
|
||||
time_period=time_period,
|
||||
location=location,
|
||||
atmosphere=atmosphere,
|
||||
rules=rules,
|
||||
characters_info=characters_info,
|
||||
current_chapter_count=current_chapter_count,
|
||||
all_chapters_brief=all_chapters_brief,
|
||||
recent_plot=recent_plot,
|
||||
plot_stage_instruction=plot_stage_instruction,
|
||||
start_chapter=start_chapter,
|
||||
end_chapter=end_chapter,
|
||||
story_direction=story_direction,
|
||||
requirements=requirements or "无特殊要求"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_single_character_prompt(cls, project_context: str, user_input: str) -> str:
|
||||
"""获取单个角色生成提示词"""
|
||||
return cls.format_prompt(
|
||||
cls.SINGLE_CHARACTER_GENERATION,
|
||||
project_context=project_context,
|
||||
user_input=user_input
|
||||
)
|
||||
|
||||
|
||||
# 创建全局提示词服务实例
|
||||
prompt_service = PromptService()
|
||||
@@ -0,0 +1,294 @@
|
||||
"""
|
||||
用户管理模块 - 支持 LinuxDO OAuth2
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, List
|
||||
from pydantic import BaseModel
|
||||
from app.config import settings, DATA_DIR
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
"""用户模型"""
|
||||
user_id: str # 格式: linuxdo_{linuxdo_id}
|
||||
username: str
|
||||
display_name: str
|
||||
avatar_url: Optional[str] = None
|
||||
trust_level: int = 0 # 仅用于显示
|
||||
is_admin: bool = False # 手动设置的管理员权限
|
||||
linuxdo_id: str # LinuxDO 用户 ID
|
||||
created_at: str
|
||||
last_login: str
|
||||
|
||||
|
||||
class UserManager:
|
||||
"""用户管理器 - 线程安全版本"""
|
||||
|
||||
USERS_FILE = str(DATA_DIR / "users.json")
|
||||
ADMINS_FILE = str(DATA_DIR / "admins.json")
|
||||
|
||||
def __init__(self):
|
||||
"""初始化用户管理器"""
|
||||
# DATA_DIR 已在 config.py 中创建,无需重复创建
|
||||
# 添加文件锁保护并发读写
|
||||
self._users_lock = asyncio.Lock()
|
||||
self._admins_lock = asyncio.Lock()
|
||||
self._ensure_files_exist()
|
||||
|
||||
def _ensure_files_exist(self):
|
||||
"""确保必要的文件存在"""
|
||||
if not os.path.exists(self.USERS_FILE):
|
||||
with open(self.USERS_FILE, "w", encoding="utf-8") as f:
|
||||
json.dump({}, f, ensure_ascii=False, indent=2)
|
||||
|
||||
if not os.path.exists(self.ADMINS_FILE):
|
||||
with open(self.ADMINS_FILE, "w", encoding="utf-8") as f:
|
||||
json.dump({"admins": []}, f, ensure_ascii=False, indent=2)
|
||||
|
||||
def _load_users_unsafe(self) -> Dict[str, dict]:
|
||||
"""加载用户数据(不加锁,内部使用)"""
|
||||
try:
|
||||
with open(self.USERS_FILE, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
print(f"加载用户数据失败: {e}")
|
||||
return {}
|
||||
|
||||
def _save_users_unsafe(self, users: Dict[str, dict]):
|
||||
"""保存用户数据(不加锁,内部使用)"""
|
||||
try:
|
||||
with open(self.USERS_FILE, "w", encoding="utf-8") as f:
|
||||
json.dump(users, f, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
print(f"保存用户数据失败: {e}")
|
||||
|
||||
async def _load_users(self) -> Dict[str, dict]:
|
||||
"""加载用户数据(加锁)"""
|
||||
async with self._users_lock:
|
||||
return self._load_users_unsafe()
|
||||
|
||||
async def _save_users(self, users: Dict[str, dict]):
|
||||
"""保存用户数据(加锁)"""
|
||||
async with self._users_lock:
|
||||
self._save_users_unsafe(users)
|
||||
|
||||
def _load_admin_list_unsafe(self) -> List[str]:
|
||||
"""加载管理员列表(不加锁,内部使用)"""
|
||||
try:
|
||||
with open(self.ADMINS_FILE, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
return data.get("admins", [])
|
||||
except Exception as e:
|
||||
print(f"加载管理员列表失败: {e}")
|
||||
return []
|
||||
|
||||
def _save_admin_list_unsafe(self, admin_list: List[str]):
|
||||
"""保存管理员列表(不加锁,内部使用)"""
|
||||
try:
|
||||
with open(self.ADMINS_FILE, "w", encoding="utf-8") as f:
|
||||
json.dump({"admins": admin_list}, f, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
print(f"保存管理员列表失败: {e}")
|
||||
|
||||
async def _load_admin_list(self) -> List[str]:
|
||||
"""加载管理员列表(加锁)"""
|
||||
async with self._admins_lock:
|
||||
return self._load_admin_list_unsafe()
|
||||
|
||||
async def _save_admin_list(self, admin_list: List[str]):
|
||||
"""保存管理员列表(加锁)"""
|
||||
async with self._admins_lock:
|
||||
self._save_admin_list_unsafe(admin_list)
|
||||
|
||||
async def create_or_update_from_linuxdo(
|
||||
self,
|
||||
linuxdo_id: str,
|
||||
username: str,
|
||||
display_name: str,
|
||||
avatar_url: Optional[str],
|
||||
trust_level: int
|
||||
) -> User:
|
||||
"""
|
||||
从 LinuxDO 用户信息创建或更新用户(线程安全)
|
||||
|
||||
Args:
|
||||
linuxdo_id: LinuxDO 用户 ID(本地用户时为 local_xxx 格式)
|
||||
username: 用户名
|
||||
display_name: 显示名称
|
||||
avatar_url: 头像 URL
|
||||
trust_level: 信任等级 (仅用于显示)
|
||||
|
||||
Returns:
|
||||
用户对象
|
||||
"""
|
||||
# 如果已经是 local_ 开头,直接使用;否则添加 linuxdo_ 前缀
|
||||
if linuxdo_id.startswith("local_"):
|
||||
user_id = linuxdo_id
|
||||
else:
|
||||
user_id = f"linuxdo_{linuxdo_id}"
|
||||
|
||||
# 使用锁保护整个读-改-写操作
|
||||
async with self._users_lock:
|
||||
async with self._admins_lock:
|
||||
users = self._load_users_unsafe()
|
||||
admin_list = self._load_admin_list_unsafe()
|
||||
|
||||
now = datetime.now().isoformat()
|
||||
|
||||
# 检查是否为初始管理员
|
||||
initial_admin_id = settings.INITIAL_ADMIN_LINUXDO_ID
|
||||
is_initial_admin = (initial_admin_id and linuxdo_id == initial_admin_id)
|
||||
|
||||
# 检查是否为本地用户(所有 local_ 开头的用户默认为管理员)
|
||||
is_local_user = user_id.startswith("local_")
|
||||
|
||||
if user_id in users:
|
||||
# 更新现有用户
|
||||
user_data = users[user_id]
|
||||
user_data["username"] = username
|
||||
user_data["display_name"] = display_name
|
||||
user_data["avatar_url"] = avatar_url
|
||||
user_data["trust_level"] = trust_level
|
||||
user_data["last_login"] = now
|
||||
|
||||
# 如果是初始管理员或本地用户且还不在管理员列表中,添加进去
|
||||
if (is_initial_admin or is_local_user) and user_id not in admin_list:
|
||||
admin_list.append(user_id)
|
||||
self._save_admin_list_unsafe(admin_list)
|
||||
user_data["is_admin"] = True
|
||||
else:
|
||||
# 从管理员列表同步 is_admin 状态
|
||||
user_data["is_admin"] = user_id in admin_list
|
||||
else:
|
||||
# 创建新用户(本地用户默认为管理员)
|
||||
is_admin = is_initial_admin or is_local_user
|
||||
if is_admin and user_id not in admin_list:
|
||||
admin_list.append(user_id)
|
||||
self._save_admin_list_unsafe(admin_list)
|
||||
|
||||
user_data = {
|
||||
"user_id": user_id,
|
||||
"username": username,
|
||||
"display_name": display_name,
|
||||
"avatar_url": avatar_url,
|
||||
"trust_level": trust_level,
|
||||
"is_admin": is_admin,
|
||||
"linuxdo_id": linuxdo_id,
|
||||
"created_at": now,
|
||||
"last_login": now
|
||||
}
|
||||
users[user_id] = user_data
|
||||
|
||||
self._save_users_unsafe(users)
|
||||
return User(**user_data)
|
||||
|
||||
async def get_user(self, user_id: str) -> Optional[User]:
|
||||
"""获取用户(线程安全)"""
|
||||
users = await self._load_users()
|
||||
user_data = users.get(user_id)
|
||||
if user_data:
|
||||
# 同步管理员状态
|
||||
admin_list = await self._load_admin_list()
|
||||
user_data["is_admin"] = user_id in admin_list
|
||||
return User(**user_data)
|
||||
return None
|
||||
|
||||
async def get_all_users(self) -> List[User]:
|
||||
"""获取所有用户(线程安全)"""
|
||||
users = await self._load_users()
|
||||
admin_list = await self._load_admin_list()
|
||||
|
||||
user_list = []
|
||||
for user_data in users.values():
|
||||
# 同步管理员状态
|
||||
user_data["is_admin"] = user_data["user_id"] in admin_list
|
||||
user_list.append(User(**user_data))
|
||||
|
||||
return user_list
|
||||
|
||||
async def set_admin(self, user_id: str, is_admin: bool) -> bool:
|
||||
"""
|
||||
设置用户的管理员权限(线程安全)
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
is_admin: 是否为管理员
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
# 使用锁保护整个读-改-写操作
|
||||
async with self._users_lock:
|
||||
async with self._admins_lock:
|
||||
users = self._load_users_unsafe()
|
||||
if user_id not in users:
|
||||
return False
|
||||
|
||||
admin_list = self._load_admin_list_unsafe()
|
||||
|
||||
if is_admin:
|
||||
# 授予管理员权限
|
||||
if user_id not in admin_list:
|
||||
admin_list.append(user_id)
|
||||
self._save_admin_list_unsafe(admin_list)
|
||||
else:
|
||||
# 撤销管理员权限
|
||||
if user_id in admin_list:
|
||||
# 确保至少保留一个管理员
|
||||
if len(admin_list) <= 1:
|
||||
return False
|
||||
admin_list.remove(user_id)
|
||||
self._save_admin_list_unsafe(admin_list)
|
||||
|
||||
# 更新用户数据中的 is_admin 字段
|
||||
users[user_id]["is_admin"] = is_admin
|
||||
self._save_users_unsafe(users)
|
||||
|
||||
return True
|
||||
|
||||
async def delete_user(self, user_id: str) -> bool:
|
||||
"""
|
||||
删除用户(线程安全)
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
# 使用锁保护整个读-改-写操作
|
||||
async with self._users_lock:
|
||||
async with self._admins_lock:
|
||||
users = self._load_users_unsafe()
|
||||
if user_id not in users:
|
||||
return False
|
||||
|
||||
# 不能删除管理员
|
||||
admin_list = self._load_admin_list_unsafe()
|
||||
if user_id in admin_list:
|
||||
return False
|
||||
|
||||
# 删除用户数据
|
||||
del users[user_id]
|
||||
self._save_users_unsafe(users)
|
||||
|
||||
# 删除用户数据库文件(在锁外执行,避免阻塞)
|
||||
db_file = str(DATA_DIR / f"ai_story_user_{user_id}.db")
|
||||
if os.path.exists(db_file):
|
||||
try:
|
||||
os.remove(db_file)
|
||||
except Exception as e:
|
||||
print(f"删除用户数据库文件失败: {e}")
|
||||
|
||||
return True
|
||||
|
||||
async def is_admin(self, user_id: str) -> bool:
|
||||
"""检查用户是否为管理员(线程安全)"""
|
||||
admin_list = await self._load_admin_list()
|
||||
return user_id in admin_list
|
||||
|
||||
|
||||
# 全局用户管理器实例
|
||||
user_manager = UserManager()
|
||||
@@ -0,0 +1,347 @@
|
||||
"""数据一致性辅助函数"""
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from typing import Optional, Tuple, List
|
||||
from app.models.character import Character
|
||||
from app.models.relationship import Organization, OrganizationMember, CharacterRelationship
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def ensure_organization_record(
|
||||
character: Character,
|
||||
db: AsyncSession,
|
||||
power_level: int = 50,
|
||||
location: Optional[str] = None,
|
||||
motto: Optional[str] = None
|
||||
) -> Optional[Organization]:
|
||||
"""
|
||||
确保组织角色拥有对应的Organization记录
|
||||
|
||||
Args:
|
||||
character: Character对象(必须是is_organization=True)
|
||||
db: 数据库会话
|
||||
power_level: 势力等级(默认50)
|
||||
location: 所在地
|
||||
motto: 宗旨/口号
|
||||
|
||||
Returns:
|
||||
Organization对象,如果character不是组织则返回None
|
||||
"""
|
||||
if not character.is_organization:
|
||||
logger.debug(f"角色 {character.name} 不是组织,跳过Organization记录创建")
|
||||
return None
|
||||
|
||||
# 检查是否已存在
|
||||
result = await db.execute(
|
||||
select(Organization).where(Organization.character_id == character.id)
|
||||
)
|
||||
org = result.scalar_one_or_none()
|
||||
|
||||
if not org:
|
||||
# 创建新的Organization记录
|
||||
org = Organization(
|
||||
character_id=character.id,
|
||||
project_id=character.project_id,
|
||||
member_count=0,
|
||||
power_level=power_level,
|
||||
location=location,
|
||||
motto=motto
|
||||
)
|
||||
db.add(org)
|
||||
await db.flush()
|
||||
await db.refresh(org)
|
||||
logger.info(f"✅ 自动创建组织详情:{character.name} (Org ID: {org.id})")
|
||||
else:
|
||||
logger.debug(f"组织详情已存在:{character.name} (Org ID: {org.id})")
|
||||
|
||||
return org
|
||||
|
||||
|
||||
async def sync_organization_member_count(
|
||||
organization: Organization,
|
||||
db: AsyncSession
|
||||
) -> int:
|
||||
"""
|
||||
同步组织的成员计数,从实际成员记录计算
|
||||
|
||||
Args:
|
||||
organization: Organization对象
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
实际成员数量
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(OrganizationMember).where(
|
||||
OrganizationMember.organization_id == organization.id,
|
||||
OrganizationMember.status == "active"
|
||||
)
|
||||
)
|
||||
members = result.scalars().all()
|
||||
actual_count = len(members)
|
||||
|
||||
if organization.member_count != actual_count:
|
||||
logger.warning(
|
||||
f"组织 {organization.id} 成员计数不一致:"
|
||||
f"记录值={organization.member_count}, 实际值={actual_count},已修正"
|
||||
)
|
||||
organization.member_count = actual_count
|
||||
await db.flush()
|
||||
|
||||
return actual_count
|
||||
|
||||
|
||||
async def fix_missing_organization_records(
|
||||
project_id: str,
|
||||
db: AsyncSession
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
修复项目中缺失的Organization记录
|
||||
|
||||
为所有is_organization=True但没有Organization记录的Character创建记录
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
(修复数量, 检查总数)
|
||||
"""
|
||||
# 查找所有组织角色
|
||||
result = await db.execute(
|
||||
select(Character).where(
|
||||
Character.project_id == project_id,
|
||||
Character.is_organization == True
|
||||
)
|
||||
)
|
||||
org_characters = result.scalars().all()
|
||||
|
||||
fixed_count = 0
|
||||
for char in org_characters:
|
||||
org = await ensure_organization_record(char, db)
|
||||
if org and org.id: # 新创建的才计数
|
||||
# 检查是否是新创建的(通过查询历史)
|
||||
result = await db.execute(
|
||||
select(Organization).where(Organization.character_id == char.id)
|
||||
)
|
||||
if result.scalar_one_or_none():
|
||||
fixed_count += 1
|
||||
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"📊 修复统计 - 检查了 {len(org_characters)} 个组织,修复了 {fixed_count} 个缺失的Organization记录")
|
||||
return fixed_count, len(org_characters)
|
||||
|
||||
|
||||
async def fix_organization_member_counts(
|
||||
project_id: str,
|
||||
db: AsyncSession
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
修复项目中所有组织的成员计数
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
(修复数量, 检查总数)
|
||||
"""
|
||||
# 查找所有组织
|
||||
result = await db.execute(
|
||||
select(Organization).where(Organization.project_id == project_id)
|
||||
)
|
||||
organizations = result.scalars().all()
|
||||
|
||||
fixed_count = 0
|
||||
for org in organizations:
|
||||
old_count = org.member_count
|
||||
actual_count = await sync_organization_member_count(org, db)
|
||||
if old_count != actual_count:
|
||||
fixed_count += 1
|
||||
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"📊 修复统计 - 检查了 {len(organizations)} 个组织,修复了 {fixed_count} 个计数错误")
|
||||
return fixed_count, len(organizations)
|
||||
|
||||
|
||||
async def validate_relationships(
|
||||
project_id: str,
|
||||
db: AsyncSession
|
||||
) -> List[dict]:
|
||||
"""
|
||||
验证项目中的关系数据完整性
|
||||
|
||||
检查所有关系中的character_from_id和character_to_id是否都指向存在的角色
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
问题列表,每个问题包含 {issue_type, relationship_id, details}
|
||||
"""
|
||||
issues = []
|
||||
|
||||
# 获取所有关系
|
||||
result = await db.execute(
|
||||
select(CharacterRelationship).where(CharacterRelationship.project_id == project_id)
|
||||
)
|
||||
relationships = result.scalars().all()
|
||||
|
||||
for rel in relationships:
|
||||
# 检查from角色
|
||||
from_char = await db.execute(
|
||||
select(Character).where(Character.id == rel.character_from_id)
|
||||
)
|
||||
if not from_char.scalar_one_or_none():
|
||||
issues.append({
|
||||
"issue_type": "missing_from_character",
|
||||
"relationship_id": rel.id,
|
||||
"details": f"关系 {rel.id} 的源角色 {rel.character_from_id} 不存在"
|
||||
})
|
||||
|
||||
# 检查to角色
|
||||
to_char = await db.execute(
|
||||
select(Character).where(Character.id == rel.character_to_id)
|
||||
)
|
||||
if not to_char.scalar_one_or_none():
|
||||
issues.append({
|
||||
"issue_type": "missing_to_character",
|
||||
"relationship_id": rel.id,
|
||||
"details": f"关系 {rel.id} 的目标角色 {rel.character_to_id} 不存在"
|
||||
})
|
||||
|
||||
if issues:
|
||||
logger.warning(f"⚠️ 发现 {len(issues)} 个关系数据问题")
|
||||
for issue in issues:
|
||||
logger.warning(f" - {issue['details']}")
|
||||
else:
|
||||
logger.info(f"✅ 所有 {len(relationships)} 条关系数据完整")
|
||||
|
||||
return issues
|
||||
|
||||
|
||||
async def validate_organization_members(
|
||||
project_id: str,
|
||||
db: AsyncSession
|
||||
) -> List[dict]:
|
||||
"""
|
||||
验证项目中的组织成员数据完整性
|
||||
|
||||
检查所有成员关系中的organization_id和character_id是否都有效
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
问题列表
|
||||
"""
|
||||
issues = []
|
||||
|
||||
# 获取所有成员关系
|
||||
result = await db.execute(
|
||||
select(OrganizationMember).where(
|
||||
OrganizationMember.organization_id.in_(
|
||||
select(Organization.id).where(Organization.project_id == project_id)
|
||||
)
|
||||
)
|
||||
)
|
||||
members = result.scalars().all()
|
||||
|
||||
for member in members:
|
||||
# 检查组织
|
||||
org = await db.execute(
|
||||
select(Organization).where(Organization.id == member.organization_id)
|
||||
)
|
||||
if not org.scalar_one_or_none():
|
||||
issues.append({
|
||||
"issue_type": "missing_organization",
|
||||
"member_id": member.id,
|
||||
"details": f"成员 {member.id} 的组织 {member.organization_id} 不存在"
|
||||
})
|
||||
|
||||
# 检查角色
|
||||
char = await db.execute(
|
||||
select(Character).where(Character.id == member.character_id)
|
||||
)
|
||||
if not char.scalar_one_or_none():
|
||||
issues.append({
|
||||
"issue_type": "missing_character",
|
||||
"member_id": member.id,
|
||||
"details": f"成员 {member.id} 的角色 {member.character_id} 不存在"
|
||||
})
|
||||
|
||||
if issues:
|
||||
logger.warning(f"⚠️ 发现 {len(issues)} 个组织成员数据问题")
|
||||
for issue in issues:
|
||||
logger.warning(f" - {issue['details']}")
|
||||
else:
|
||||
logger.info(f"✅ 所有 {len(members)} 条组织成员数据完整")
|
||||
|
||||
return issues
|
||||
|
||||
|
||||
async def run_full_data_consistency_check(
|
||||
project_id: str,
|
||||
db: AsyncSession,
|
||||
auto_fix: bool = True
|
||||
) -> dict:
|
||||
"""
|
||||
对项目运行完整的数据一致性检查和修复
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
db: 数据库会话
|
||||
auto_fix: 是否自动修复问题(默认True)
|
||||
|
||||
Returns:
|
||||
检查报告字典
|
||||
"""
|
||||
logger.info(f"🔍 开始数据一致性检查 - 项目 {project_id}")
|
||||
|
||||
report = {
|
||||
"project_id": project_id,
|
||||
"checks": {}
|
||||
}
|
||||
|
||||
# 1. 检查并修复缺失的Organization记录
|
||||
if auto_fix:
|
||||
fixed, total = await fix_missing_organization_records(project_id, db)
|
||||
report["checks"]["organization_records"] = {
|
||||
"checked": total,
|
||||
"fixed": fixed,
|
||||
"status": "ok" if fixed == 0 else "fixed"
|
||||
}
|
||||
|
||||
# 2. 检查并修复成员计数
|
||||
if auto_fix:
|
||||
fixed, total = await fix_organization_member_counts(project_id, db)
|
||||
report["checks"]["member_counts"] = {
|
||||
"checked": total,
|
||||
"fixed": fixed,
|
||||
"status": "ok" if fixed == 0 else "fixed"
|
||||
}
|
||||
|
||||
# 3. 验证关系数据
|
||||
rel_issues = await validate_relationships(project_id, db)
|
||||
report["checks"]["relationships"] = {
|
||||
"issues_found": len(rel_issues),
|
||||
"issues": rel_issues,
|
||||
"status": "ok" if len(rel_issues) == 0 else "warning"
|
||||
}
|
||||
|
||||
# 4. 验证组织成员数据
|
||||
member_issues = await validate_organization_members(project_id, db)
|
||||
report["checks"]["organization_members"] = {
|
||||
"issues_found": len(member_issues),
|
||||
"issues": member_issues,
|
||||
"status": "ok" if len(member_issues) == 0 else "warning"
|
||||
}
|
||||
|
||||
logger.info(f"✅ 数据一致性检查完成")
|
||||
return report
|
||||
@@ -0,0 +1,169 @@
|
||||
"""Server-Sent Events (SSE) 响应工具类"""
|
||||
import json
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Dict, Any, Optional
|
||||
from fastapi.responses import StreamingResponse
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SSEResponse:
|
||||
"""SSE响应构建器"""
|
||||
|
||||
@staticmethod
|
||||
def format_sse(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
||||
"""
|
||||
格式化SSE消息
|
||||
|
||||
Args:
|
||||
data: 要发送的数据字典
|
||||
event: 事件类型(可选)
|
||||
|
||||
Returns:
|
||||
格式化后的SSE消息字符串
|
||||
"""
|
||||
message = ""
|
||||
if event:
|
||||
message += f"event: {event}\n"
|
||||
message += f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||
return message
|
||||
|
||||
@staticmethod
|
||||
async def send_progress(
|
||||
message: str,
|
||||
progress: int,
|
||||
status: str = "processing"
|
||||
) -> str:
|
||||
"""
|
||||
发送进度消息
|
||||
|
||||
Args:
|
||||
message: 进度消息
|
||||
progress: 进度百分比(0-100)
|
||||
status: 状态(processing/success/error)
|
||||
"""
|
||||
return SSEResponse.format_sse({
|
||||
"type": "progress",
|
||||
"message": message,
|
||||
"progress": progress,
|
||||
"status": status
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
async def send_chunk(content: str) -> str:
|
||||
"""
|
||||
发送内容块(用于流式输出AI生成内容)
|
||||
|
||||
Args:
|
||||
content: 内容块
|
||||
"""
|
||||
return SSEResponse.format_sse({
|
||||
"type": "chunk",
|
||||
"content": content
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
async def send_result(data: Dict[str, Any]) -> str:
|
||||
"""
|
||||
发送最终结果
|
||||
|
||||
Args:
|
||||
data: 结果数据
|
||||
"""
|
||||
return SSEResponse.format_sse({
|
||||
"type": "result",
|
||||
"data": data
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
async def send_error(error: str, code: int = 500) -> str:
|
||||
"""
|
||||
发送错误消息
|
||||
|
||||
Args:
|
||||
error: 错误描述
|
||||
code: 错误码
|
||||
"""
|
||||
return SSEResponse.format_sse({
|
||||
"type": "error",
|
||||
"error": error,
|
||||
"code": code
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
async def send_done() -> str:
|
||||
"""发送完成消息"""
|
||||
return SSEResponse.format_sse({
|
||||
"type": "done"
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
async def send_heartbeat() -> str:
|
||||
"""发送心跳消息(保持连接活跃)"""
|
||||
return ": heartbeat\n\n"
|
||||
|
||||
|
||||
async def create_sse_generator(
|
||||
async_gen: AsyncGenerator[str, None],
|
||||
show_progress: bool = True
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
创建SSE生成器包装器
|
||||
|
||||
Args:
|
||||
async_gen: 异步生成器
|
||||
show_progress: 是否显示进度
|
||||
|
||||
Yields:
|
||||
格式化的SSE消息
|
||||
"""
|
||||
try:
|
||||
if show_progress:
|
||||
yield await SSEResponse.send_progress("开始生成...", 0)
|
||||
|
||||
# 累积内容用于进度计算
|
||||
accumulated_content = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in async_gen:
|
||||
chunk_count += 1
|
||||
accumulated_content += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 每10个块发送一次心跳
|
||||
if chunk_count % 10 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
if show_progress:
|
||||
yield await SSEResponse.send_progress("生成完成", 100, "success")
|
||||
|
||||
# 发送完成信号
|
||||
yield await SSEResponse.send_done()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"SSE生成器错误: {str(e)}")
|
||||
yield await SSEResponse.send_error(str(e))
|
||||
|
||||
|
||||
def create_sse_response(generator: AsyncGenerator[str, None]) -> StreamingResponse:
|
||||
"""
|
||||
创建SSE StreamingResponse
|
||||
|
||||
Args:
|
||||
generator: SSE消息生成器
|
||||
|
||||
Returns:
|
||||
StreamingResponse对象
|
||||
"""
|
||||
return StreamingResponse(
|
||||
generator,
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # 禁用nginx缓冲
|
||||
}
|
||||
)
|
||||
Reference in New Issue
Block a user