init
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""API路由模块"""
|
||||
@@ -0,0 +1,230 @@
|
||||
"""
|
||||
认证 API - LinuxDO OAuth2 登录 + 本地账户登录
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, Response, Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
import hashlib
|
||||
from app.services.oauth_service import LinuxDOOAuthService
|
||||
from app.user_manager import user_manager
|
||||
from app.database import init_db
|
||||
from app.logger import get_logger
|
||||
from app.config import settings
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["认证"])
|
||||
|
||||
# OAuth2 服务实例
|
||||
oauth_service = LinuxDOOAuthService()
|
||||
|
||||
# State 临时存储(生产环境应使用 Redis)
|
||||
_state_storage = {}
|
||||
|
||||
|
||||
class AuthUrlResponse(BaseModel):
|
||||
auth_url: str
|
||||
state: str
|
||||
|
||||
|
||||
class LocalLoginRequest(BaseModel):
|
||||
"""本地登录请求"""
|
||||
username: str
|
||||
password: str
|
||||
|
||||
|
||||
class LocalLoginResponse(BaseModel):
|
||||
"""本地登录响应"""
|
||||
success: bool
|
||||
message: str
|
||||
user: Optional[dict] = None
|
||||
|
||||
|
||||
@router.get("/config")
|
||||
async def get_auth_config():
|
||||
"""获取认证配置信息"""
|
||||
return {
|
||||
"local_auth_enabled": settings.LOCAL_AUTH_ENABLED,
|
||||
"linuxdo_enabled": bool(settings.LINUXDO_CLIENT_ID and settings.LINUXDO_CLIENT_SECRET)
|
||||
}
|
||||
|
||||
|
||||
@router.post("/local/login", response_model=LocalLoginResponse)
|
||||
async def local_login(request: LocalLoginRequest, response: Response):
|
||||
"""本地账户登录"""
|
||||
# 检查是否启用本地登录
|
||||
if not settings.LOCAL_AUTH_ENABLED:
|
||||
raise HTTPException(status_code=403, detail="本地账户登录未启用")
|
||||
|
||||
# 检查是否配置了本地账户
|
||||
if not settings.LOCAL_AUTH_USERNAME or not settings.LOCAL_AUTH_PASSWORD:
|
||||
raise HTTPException(status_code=500, detail="本地账户未配置")
|
||||
|
||||
# 验证用户名和密码
|
||||
if request.username != settings.LOCAL_AUTH_USERNAME or request.password != settings.LOCAL_AUTH_PASSWORD:
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
# 生成本地用户ID(使用用户名的hash)
|
||||
user_id = f"local_{hashlib.md5(request.username.encode()).hexdigest()[:16]}"
|
||||
|
||||
# 创建或更新本地用户
|
||||
user = await user_manager.create_or_update_from_linuxdo(
|
||||
linuxdo_id=user_id,
|
||||
username=request.username,
|
||||
display_name=settings.LOCAL_AUTH_DISPLAY_NAME,
|
||||
avatar_url=None,
|
||||
trust_level=9 # 本地用户给予高信任级别
|
||||
)
|
||||
|
||||
# 初始化用户数据库
|
||||
try:
|
||||
await init_db(user.user_id)
|
||||
logger.info(f"本地用户 {user.user_id} 数据库初始化成功")
|
||||
except Exception as e:
|
||||
logger.error(f"本地用户 {user.user_id} 数据库初始化失败: {e}")
|
||||
|
||||
# 设置 Cookie(7天有效)
|
||||
response.set_cookie(
|
||||
key="user_id",
|
||||
value=user.user_id,
|
||||
max_age=7 * 24 * 60 * 60, # 7天
|
||||
httponly=True,
|
||||
samesite="lax"
|
||||
)
|
||||
|
||||
return LocalLoginResponse(
|
||||
success=True,
|
||||
message="登录成功",
|
||||
user=user.dict()
|
||||
)
|
||||
|
||||
|
||||
@router.get("/linuxdo/url", response_model=AuthUrlResponse)
|
||||
async def get_linuxdo_auth_url():
|
||||
"""获取 LinuxDO 授权 URL"""
|
||||
state = oauth_service.generate_state()
|
||||
auth_url = oauth_service.get_authorization_url(state)
|
||||
|
||||
# 临时存储 state(5分钟有效)
|
||||
_state_storage[state] = True
|
||||
|
||||
return AuthUrlResponse(auth_url=auth_url, state=state)
|
||||
|
||||
|
||||
async def _handle_callback(
|
||||
code: Optional[str] = None,
|
||||
state: Optional[str] = None,
|
||||
error: Optional[str] = None,
|
||||
response: Response = None
|
||||
):
|
||||
"""
|
||||
LinuxDO OAuth2 回调处理
|
||||
|
||||
成功后重定向到前端首页,并设置 user_id Cookie
|
||||
"""
|
||||
# 检查是否有错误
|
||||
if error:
|
||||
raise HTTPException(status_code=400, detail=f"授权失败: {error}")
|
||||
|
||||
# 检查必需参数
|
||||
if not code or not state:
|
||||
raise HTTPException(status_code=400, detail="缺少 code 或 state 参数")
|
||||
|
||||
# 验证 state(防止 CSRF)
|
||||
if state not in _state_storage:
|
||||
raise HTTPException(status_code=400, detail="无效的 state 参数")
|
||||
|
||||
# 删除已使用的 state
|
||||
del _state_storage[state]
|
||||
|
||||
# 1. 使用 code 获取 access_token
|
||||
token_data = await oauth_service.get_access_token(code)
|
||||
if not token_data or "access_token" not in token_data:
|
||||
raise HTTPException(status_code=400, detail="获取访问令牌失败")
|
||||
|
||||
access_token = token_data["access_token"]
|
||||
|
||||
# 2. 使用 access_token 获取用户信息
|
||||
user_info = await oauth_service.get_user_info(access_token)
|
||||
if not user_info:
|
||||
raise HTTPException(status_code=400, detail="获取用户信息失败")
|
||||
|
||||
# 3. 创建或更新用户
|
||||
linuxdo_id = str(user_info.get("id"))
|
||||
username = user_info.get("username", "")
|
||||
display_name = user_info.get("name", username)
|
||||
avatar_url = user_info.get("avatar_url")
|
||||
trust_level = user_info.get("trust_level", 0)
|
||||
|
||||
user = await user_manager.create_or_update_from_linuxdo(
|
||||
linuxdo_id=linuxdo_id,
|
||||
username=username,
|
||||
display_name=display_name,
|
||||
avatar_url=avatar_url,
|
||||
trust_level=trust_level
|
||||
)
|
||||
|
||||
# 3.5. 初始化用户数据库(如果是新用户)
|
||||
try:
|
||||
await init_db(user.user_id)
|
||||
logger.info(f"用户 {user.user_id} 数据库初始化成功")
|
||||
except Exception as e:
|
||||
logger.error(f"用户 {user.user_id} 数据库初始化失败: {e}")
|
||||
# 继续执行,不影响登录流程(可能是已存在的用户)
|
||||
|
||||
# 4. 设置 Cookie 并重定向到前端回调页面
|
||||
# 使用配置的前端URL,支持不同的部署环境
|
||||
frontend_url = settings.FRONTEND_URL.rstrip('/')
|
||||
redirect_url = f"{frontend_url}/auth/callback"
|
||||
logger.info(f"OAuth回调成功,重定向到前端: {redirect_url}")
|
||||
redirect_response = RedirectResponse(url=redirect_url)
|
||||
|
||||
# 设置 httponly Cookie(7天有效)
|
||||
redirect_response.set_cookie(
|
||||
key="user_id",
|
||||
value=user.user_id,
|
||||
max_age=7 * 24 * 60 * 60, # 7天
|
||||
httponly=True,
|
||||
samesite="lax"
|
||||
)
|
||||
|
||||
return redirect_response
|
||||
|
||||
|
||||
@router.get("/linuxdo/callback")
|
||||
async def linuxdo_callback(
|
||||
code: Optional[str] = None,
|
||||
state: Optional[str] = None,
|
||||
error: Optional[str] = None,
|
||||
response: Response = None
|
||||
):
|
||||
"""LinuxDO OAuth2 回调处理(标准路径)"""
|
||||
return await _handle_callback(code, state, error, response)
|
||||
|
||||
|
||||
@router.get("/callback")
|
||||
async def callback_alias(
|
||||
code: Optional[str] = None,
|
||||
state: Optional[str] = None,
|
||||
error: Optional[str] = None,
|
||||
response: Response = None
|
||||
):
|
||||
"""LinuxDO OAuth2 回调处理(兼容路径)"""
|
||||
return await _handle_callback(code, state, error, response)
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
async def logout(response: Response):
|
||||
"""退出登录"""
|
||||
response.delete_cookie("user_id")
|
||||
return {"message": "退出登录成功"}
|
||||
|
||||
|
||||
@router.get("/user")
|
||||
async def get_current_user(request: Request):
|
||||
"""获取当前登录用户信息"""
|
||||
if not hasattr(request.state, "user") or not request.state.user:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
return request.state.user.dict()
|
||||
@@ -0,0 +1,655 @@
|
||||
"""章节管理API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
import json
|
||||
import asyncio
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.chapter import Chapter
|
||||
from app.models.project import Project
|
||||
from app.models.outline import Outline
|
||||
from app.models.character import Character
|
||||
from app.models.generation_history import GenerationHistory
|
||||
from app.schemas.chapter import (
|
||||
ChapterCreate,
|
||||
ChapterUpdate,
|
||||
ChapterResponse,
|
||||
ChapterListResponse
|
||||
)
|
||||
from app.services.ai_service import ai_service
|
||||
from app.services.prompt_service import prompt_service
|
||||
from app.logger import get_logger
|
||||
|
||||
router = APIRouter(prefix="/chapters", tags=["章节管理"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@router.post("", response_model=ChapterResponse, summary="创建章节")
|
||||
async def create_chapter(
|
||||
chapter: ChapterCreate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""创建新的章节"""
|
||||
# 验证项目是否存在
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == chapter.project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
# 计算字数
|
||||
word_count = len(chapter.content)
|
||||
|
||||
db_chapter = Chapter(
|
||||
**chapter.model_dump(),
|
||||
word_count=word_count
|
||||
)
|
||||
db.add(db_chapter)
|
||||
|
||||
# 更新项目的当前字数
|
||||
project.current_words = project.current_words + word_count
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(db_chapter)
|
||||
return db_chapter
|
||||
|
||||
|
||||
@router.get("/project/{project_id}", response_model=ChapterListResponse, summary="获取项目的所有章节")
|
||||
async def get_project_chapters(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取指定项目的所有章节(路径参数版本)"""
|
||||
# 获取总数
|
||||
count_result = await db.execute(
|
||||
select(func.count(Chapter.id)).where(Chapter.project_id == project_id)
|
||||
)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# 获取章节列表
|
||||
result = await db.execute(
|
||||
select(Chapter)
|
||||
.where(Chapter.project_id == project_id)
|
||||
.order_by(Chapter.chapter_number)
|
||||
)
|
||||
chapters = result.scalars().all()
|
||||
|
||||
return ChapterListResponse(total=total, items=chapters)
|
||||
|
||||
|
||||
@router.get("/{chapter_id}", response_model=ChapterResponse, summary="获取章节详情")
|
||||
async def get_chapter(
|
||||
chapter_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""根据ID获取章节详情"""
|
||||
result = await db.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
chapter = result.scalar_one_or_none()
|
||||
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
return chapter
|
||||
|
||||
|
||||
@router.put("/{chapter_id}", response_model=ChapterResponse, summary="更新章节")
|
||||
async def update_chapter(
|
||||
chapter_id: str,
|
||||
chapter_update: ChapterUpdate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""更新章节信息"""
|
||||
result = await db.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
chapter = result.scalar_one_or_none()
|
||||
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
# 记录旧字数
|
||||
old_word_count = chapter.word_count or 0
|
||||
|
||||
# 更新字段
|
||||
update_data = chapter_update.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(chapter, field, value)
|
||||
|
||||
# 如果内容更新了,重新计算字数
|
||||
if "content" in update_data and chapter.content:
|
||||
new_word_count = len(chapter.content)
|
||||
chapter.word_count = new_word_count
|
||||
|
||||
# 更新项目字数
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == chapter.project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if project:
|
||||
project.current_words = project.current_words - old_word_count + new_word_count
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(chapter)
|
||||
return chapter
|
||||
|
||||
|
||||
@router.delete("/{chapter_id}", summary="删除章节")
|
||||
async def delete_chapter(
|
||||
chapter_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""删除章节"""
|
||||
result = await db.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
chapter = result.scalar_one_or_none()
|
||||
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
# 更新项目字数
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == chapter.project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if project:
|
||||
project.current_words = max(0, project.current_words - chapter.word_count)
|
||||
|
||||
await db.delete(chapter)
|
||||
await db.commit()
|
||||
|
||||
return {"message": "章节删除成功"}
|
||||
|
||||
|
||||
async def check_prerequisites(db: AsyncSession, chapter: Chapter) -> tuple[bool, str, list[Chapter]]:
|
||||
"""
|
||||
检查章节前置条件
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
chapter: 当前章节
|
||||
|
||||
Returns:
|
||||
(可否生成, 错误信息, 前置章节列表)
|
||||
"""
|
||||
# 如果是第一章,无需检查前置
|
||||
if chapter.chapter_number == 1:
|
||||
return True, "", []
|
||||
|
||||
# 查询所有前置章节(序号小于当前章节的)
|
||||
result = await db.execute(
|
||||
select(Chapter)
|
||||
.where(Chapter.project_id == chapter.project_id)
|
||||
.where(Chapter.chapter_number < chapter.chapter_number)
|
||||
.order_by(Chapter.chapter_number)
|
||||
)
|
||||
previous_chapters = result.scalars().all()
|
||||
|
||||
# 检查是否所有前置章节都有内容
|
||||
incomplete_chapters = [
|
||||
ch for ch in previous_chapters
|
||||
if not ch.content or ch.content.strip() == ""
|
||||
]
|
||||
|
||||
if incomplete_chapters:
|
||||
missing_numbers = [str(ch.chapter_number) for ch in incomplete_chapters]
|
||||
error_msg = f"需要先完成前置章节:第 {', '.join(missing_numbers)} 章"
|
||||
return False, error_msg, previous_chapters
|
||||
|
||||
return True, "", previous_chapters
|
||||
|
||||
|
||||
@router.get("/{chapter_id}/can-generate", summary="检查章节是否可以生成")
|
||||
async def check_can_generate(
|
||||
chapter_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
检查章节是否满足生成条件
|
||||
返回可生成状态和前置章节信息
|
||||
"""
|
||||
# 获取章节
|
||||
result = await db.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
chapter = result.scalar_one_or_none()
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
# 检查前置条件
|
||||
can_generate, error_msg, previous_chapters = await check_prerequisites(db, chapter)
|
||||
|
||||
# 构建前置章节信息
|
||||
previous_info = [
|
||||
{
|
||||
"id": ch.id,
|
||||
"chapter_number": ch.chapter_number,
|
||||
"title": ch.title,
|
||||
"has_content": bool(ch.content and ch.content.strip()),
|
||||
"word_count": ch.word_count or 0
|
||||
}
|
||||
for ch in previous_chapters
|
||||
]
|
||||
|
||||
return {
|
||||
"can_generate": can_generate,
|
||||
"reason": error_msg if not can_generate else "",
|
||||
"previous_chapters": previous_info,
|
||||
"chapter_number": chapter.chapter_number
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{chapter_id}/generate", summary="AI创作章节内容")
|
||||
async def generate_chapter_content(
|
||||
chapter_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
根据大纲、前置章节内容和项目信息AI创作章节完整内容
|
||||
要求:必须按顺序生成,确保前置章节都已完成
|
||||
"""
|
||||
# 获取章节
|
||||
result = await db.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
chapter = result.scalar_one_or_none()
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
# 检查前置条件
|
||||
can_generate, error_msg, previous_chapters = await check_prerequisites(db, chapter)
|
||||
if not can_generate:
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
|
||||
try:
|
||||
# 获取项目信息
|
||||
project_result = await db.execute(
|
||||
select(Project).where(Project.id == chapter.project_id)
|
||||
)
|
||||
project = project_result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
# 获取对应的大纲(使用新的查询确保获取最新数据)
|
||||
outline_result = await db.execute(
|
||||
select(Outline)
|
||||
.where(Outline.project_id == chapter.project_id)
|
||||
.where(Outline.order_index == chapter.chapter_number)
|
||||
.execution_options(populate_existing=True)
|
||||
)
|
||||
outline = outline_result.scalar_one_or_none()
|
||||
|
||||
# 获取所有大纲用于上下文(使用新的查询确保获取最新数据)
|
||||
all_outlines_result = await db.execute(
|
||||
select(Outline)
|
||||
.where(Outline.project_id == chapter.project_id)
|
||||
.order_by(Outline.order_index)
|
||||
.execution_options(populate_existing=True)
|
||||
)
|
||||
all_outlines = all_outlines_result.scalars().all()
|
||||
outlines_context = "\n".join([
|
||||
f"第{o.order_index}章 {o.title}: {o.content[:100]}..."
|
||||
for o in all_outlines
|
||||
])
|
||||
|
||||
# 获取角色信息
|
||||
characters_result = await db.execute(
|
||||
select(Character).where(Character.project_id == chapter.project_id)
|
||||
)
|
||||
characters = characters_result.scalars().all()
|
||||
characters_info = "\n".join([
|
||||
f"- {c.name}({'组织' if c.is_organization else '角色'}, {c.role_type}): {c.personality[:100] if c.personality else ''}"
|
||||
for c in characters
|
||||
])
|
||||
|
||||
# 构建前置章节内容上下文(如果有前置章节)
|
||||
previous_content = ""
|
||||
if previous_chapters:
|
||||
# Token控制:保留最近3章的完整内容,早期章节使用摘要
|
||||
recent_chapters = previous_chapters[-3:] if len(previous_chapters) > 3 else previous_chapters
|
||||
early_chapters = previous_chapters[:-3] if len(previous_chapters) > 3 else []
|
||||
|
||||
# 早期章节摘要
|
||||
if early_chapters:
|
||||
early_summary = "【前期剧情概要】\n" + "\n".join([
|
||||
f"第{ch.chapter_number}章《{ch.title}》:{ch.content[:200] if ch.content else ''}..."
|
||||
for ch in early_chapters
|
||||
])
|
||||
previous_content += early_summary + "\n\n"
|
||||
|
||||
# 最近章节完整内容
|
||||
if recent_chapters:
|
||||
recent_content = "【最近章节完整内容】\n" + "\n\n".join([
|
||||
f"=== 第{ch.chapter_number}章:{ch.title} ===\n{ch.content}"
|
||||
for ch in recent_chapters
|
||||
])
|
||||
previous_content += recent_content
|
||||
|
||||
logger.info(f"构建前置上下文:{len(early_chapters)}章摘要 + {len(recent_chapters)}章完整内容")
|
||||
|
||||
# 根据是否有前置内容选择不同的提示词
|
||||
if previous_content:
|
||||
# 使用带上下文的提示词
|
||||
prompt = prompt_service.get_chapter_generation_with_context_prompt(
|
||||
title=project.title,
|
||||
theme=project.theme or '',
|
||||
genre=project.genre or '',
|
||||
narrative_perspective=project.narrative_perspective or '第三人称',
|
||||
time_period=project.world_time_period or '未设定',
|
||||
location=project.world_location or '未设定',
|
||||
atmosphere=project.world_atmosphere or '未设定',
|
||||
rules=project.world_rules or '未设定',
|
||||
characters_info=characters_info or '暂无角色信息',
|
||||
outlines_context=outlines_context,
|
||||
previous_content=previous_content,
|
||||
chapter_number=chapter.chapter_number,
|
||||
chapter_title=chapter.title,
|
||||
chapter_outline=outline.content if outline else chapter.summary or '暂无大纲'
|
||||
)
|
||||
else:
|
||||
# 第一章,使用原有提示词
|
||||
prompt = prompt_service.get_chapter_generation_prompt(
|
||||
title=project.title,
|
||||
theme=project.theme or '',
|
||||
genre=project.genre or '',
|
||||
narrative_perspective=project.narrative_perspective or '第三人称',
|
||||
time_period=project.world_time_period or '未设定',
|
||||
location=project.world_location or '未设定',
|
||||
atmosphere=project.world_atmosphere or '未设定',
|
||||
rules=project.world_rules or '未设定',
|
||||
characters_info=characters_info or '暂无角色信息',
|
||||
outlines_context=outlines_context,
|
||||
chapter_number=chapter.chapter_number,
|
||||
chapter_title=chapter.title,
|
||||
chapter_outline=outline.content if outline else chapter.summary or '暂无大纲'
|
||||
)
|
||||
|
||||
logger.info(f"开始AI创作章节 {chapter_id}")
|
||||
|
||||
# 调用AI生成
|
||||
ai_content = await ai_service.generate_text(
|
||||
prompt=prompt
|
||||
)
|
||||
|
||||
# 更新章节内容
|
||||
old_word_count = chapter.word_count or 0
|
||||
chapter.content = ai_content
|
||||
new_word_count = len(ai_content)
|
||||
chapter.word_count = new_word_count
|
||||
chapter.status = "completed"
|
||||
|
||||
# 更新项目字数
|
||||
project.current_words = project.current_words - old_word_count + new_word_count
|
||||
|
||||
# 记录生成历史
|
||||
history = GenerationHistory(
|
||||
project_id=chapter.project_id,
|
||||
chapter_id=chapter.id,
|
||||
prompt=f"创作章节: 第{chapter.chapter_number}章 {chapter.title}",
|
||||
generated_content=ai_content[:500] if len(ai_content) > 500 else ai_content,
|
||||
model="default"
|
||||
)
|
||||
db.add(history)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(chapter)
|
||||
|
||||
logger.info(f"成功创作章节 {chapter_id},共 {new_word_count} 字")
|
||||
|
||||
return {"content": ai_content}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创作章节失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"创作章节失败: {str(e)}")
|
||||
|
||||
@router.post("/{chapter_id}/generate-stream", summary="AI创作章节内容(流式)")
|
||||
async def generate_chapter_content_stream(
|
||||
chapter_id: str,
|
||||
request: Request
|
||||
):
|
||||
"""
|
||||
根据大纲、前置章节内容和项目信息AI创作章节完整内容(流式返回)
|
||||
要求:必须按顺序生成,确保前置章节都已完成
|
||||
|
||||
注意:此函数不使用依赖注入的db,而是在生成器内部创建独立的数据库会话
|
||||
以避免流式响应期间的连接泄漏问题
|
||||
"""
|
||||
# 预先验证章节存在性(使用临时会话)
|
||||
async for temp_db in get_db(request):
|
||||
try:
|
||||
result = await temp_db.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
chapter = result.scalar_one_or_none()
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
# 检查前置条件
|
||||
can_generate, error_msg, previous_chapters = await check_prerequisites(temp_db, chapter)
|
||||
if not can_generate:
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
|
||||
# 保存前置章节数据供生成器使用
|
||||
previous_chapters_data = [
|
||||
{
|
||||
'id': ch.id,
|
||||
'chapter_number': ch.chapter_number,
|
||||
'title': ch.title,
|
||||
'content': ch.content
|
||||
}
|
||||
for ch in previous_chapters
|
||||
]
|
||||
finally:
|
||||
await temp_db.close()
|
||||
break
|
||||
|
||||
async def event_generator():
|
||||
# 在生成器内部创建独立的数据库会话
|
||||
db_session = None
|
||||
db_committed = False
|
||||
try:
|
||||
# 创建新的数据库会话
|
||||
async for db_session in get_db(request):
|
||||
# 重新获取章节信息
|
||||
chapter_result = await db_session.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
current_chapter = chapter_result.scalar_one_or_none()
|
||||
if not current_chapter:
|
||||
yield f"data: {json.dumps({'type': 'error', 'error': '章节不存在'}, ensure_ascii=False)}\n\n"
|
||||
return
|
||||
|
||||
# 获取项目信息
|
||||
project_result = await db_session.execute(
|
||||
select(Project).where(Project.id == current_chapter.project_id)
|
||||
)
|
||||
project = project_result.scalar_one_or_none()
|
||||
if not project:
|
||||
yield f"data: {json.dumps({'type': 'error', 'error': '项目不存在'}, ensure_ascii=False)}\n\n"
|
||||
return
|
||||
|
||||
# 获取对应的大纲
|
||||
outline_result = await db_session.execute(
|
||||
select(Outline)
|
||||
.where(Outline.project_id == current_chapter.project_id)
|
||||
.where(Outline.order_index == current_chapter.chapter_number)
|
||||
.execution_options(populate_existing=True)
|
||||
)
|
||||
outline = outline_result.scalar_one_or_none()
|
||||
|
||||
# 获取所有大纲用于上下文
|
||||
all_outlines_result = await db_session.execute(
|
||||
select(Outline)
|
||||
.where(Outline.project_id == current_chapter.project_id)
|
||||
.order_by(Outline.order_index)
|
||||
.execution_options(populate_existing=True)
|
||||
)
|
||||
all_outlines = all_outlines_result.scalars().all()
|
||||
outlines_context = "\n".join([
|
||||
f"第{o.order_index}章 {o.title}: {o.content[:100]}..."
|
||||
for o in all_outlines
|
||||
])
|
||||
|
||||
# 获取角色信息
|
||||
characters_result = await db_session.execute(
|
||||
select(Character).where(Character.project_id == current_chapter.project_id)
|
||||
)
|
||||
characters = characters_result.scalars().all()
|
||||
characters_info = "\n".join([
|
||||
f"- {c.name}({'组织' if c.is_organization else '角色'}, {c.role_type}): {c.personality[:100] if c.personality else ''}"
|
||||
for c in characters
|
||||
])
|
||||
|
||||
# 构建前置章节内容上下文(使用之前保存的数据)
|
||||
previous_content = ""
|
||||
if previous_chapters_data:
|
||||
recent_chapters = previous_chapters_data[-3:] if len(previous_chapters_data) > 3 else previous_chapters_data
|
||||
early_chapters = previous_chapters_data[:-3] if len(previous_chapters_data) > 3 else []
|
||||
|
||||
if early_chapters:
|
||||
early_summary = "【前期剧情概要】\n" + "\n".join([
|
||||
f"第{ch['chapter_number']}章《{ch['title']}》:{ch['content'][:200] if ch['content'] else ''}..."
|
||||
for ch in early_chapters
|
||||
])
|
||||
previous_content += early_summary + "\n\n"
|
||||
|
||||
if recent_chapters:
|
||||
recent_content = "【最近章节完整内容】\n" + "\n\n".join([
|
||||
f"=== 第{ch['chapter_number']}章:{ch['title']} ===\n{ch['content']}"
|
||||
for ch in recent_chapters
|
||||
])
|
||||
previous_content += recent_content
|
||||
|
||||
logger.info(f"构建前置上下文:{len(early_chapters)}章摘要 + {len(recent_chapters)}章完整内容")
|
||||
|
||||
# 发送开始事件
|
||||
yield f"data: {json.dumps({'type': 'start', 'message': '开始AI创作...'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 根据是否有前置内容选择不同的提示词
|
||||
if previous_content:
|
||||
prompt = prompt_service.get_chapter_generation_with_context_prompt(
|
||||
title=project.title,
|
||||
theme=project.theme or '',
|
||||
genre=project.genre or '',
|
||||
narrative_perspective=project.narrative_perspective or '第三人称',
|
||||
time_period=project.world_time_period or '未设定',
|
||||
location=project.world_location or '未设定',
|
||||
atmosphere=project.world_atmosphere or '未设定',
|
||||
rules=project.world_rules or '未设定',
|
||||
characters_info=characters_info or '暂无角色信息',
|
||||
outlines_context=outlines_context,
|
||||
previous_content=previous_content,
|
||||
chapter_number=current_chapter.chapter_number,
|
||||
chapter_title=current_chapter.title,
|
||||
chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲'
|
||||
)
|
||||
else:
|
||||
prompt = prompt_service.get_chapter_generation_prompt(
|
||||
title=project.title,
|
||||
theme=project.theme or '',
|
||||
genre=project.genre or '',
|
||||
narrative_perspective=project.narrative_perspective or '第三人称',
|
||||
time_period=project.world_time_period or '未设定',
|
||||
location=project.world_location or '未设定',
|
||||
atmosphere=project.world_atmosphere or '未设定',
|
||||
rules=project.world_rules or '未设定',
|
||||
characters_info=characters_info or '暂无角色信息',
|
||||
outlines_context=outlines_context,
|
||||
chapter_number=current_chapter.chapter_number,
|
||||
chapter_title=current_chapter.title,
|
||||
chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲'
|
||||
)
|
||||
|
||||
logger.info(f"开始AI流式创作章节 {chapter_id}")
|
||||
|
||||
# 流式生成内容
|
||||
full_content = ""
|
||||
async for chunk in ai_service.generate_text_stream(prompt=prompt):
|
||||
full_content += chunk
|
||||
yield f"data: {json.dumps({'type': 'content', 'content': chunk}, ensure_ascii=False)}\n\n"
|
||||
await asyncio.sleep(0) # 让出控制权
|
||||
|
||||
# 更新章节内容到数据库
|
||||
old_word_count = current_chapter.word_count or 0
|
||||
current_chapter.content = full_content
|
||||
new_word_count = len(full_content)
|
||||
current_chapter.word_count = new_word_count
|
||||
current_chapter.status = "completed"
|
||||
|
||||
# 更新项目字数
|
||||
project.current_words = project.current_words - old_word_count + new_word_count
|
||||
|
||||
# 记录生成历史
|
||||
history = GenerationHistory(
|
||||
project_id=current_chapter.project_id,
|
||||
chapter_id=current_chapter.id,
|
||||
prompt=f"创作章节: 第{current_chapter.chapter_number}章 {current_chapter.title}",
|
||||
generated_content=full_content[:500] if len(full_content) > 500 else full_content,
|
||||
model="default"
|
||||
)
|
||||
db_session.add(history)
|
||||
|
||||
await db_session.commit()
|
||||
db_committed = True
|
||||
await db_session.refresh(current_chapter)
|
||||
|
||||
logger.info(f"成功创作章节 {chapter_id},共 {new_word_count} 字")
|
||||
|
||||
# 发送完成事件
|
||||
yield f"data: {json.dumps({'type': 'done', 'message': '创作完成', 'word_count': new_word_count}, ensure_ascii=False)}\n\n"
|
||||
|
||||
break # 退出async for db_session循环
|
||||
|
||||
except GeneratorExit:
|
||||
# SSE连接断开
|
||||
logger.warning("章节生成器被提前关闭(SSE断开)")
|
||||
if db_session and not db_committed:
|
||||
try:
|
||||
if db_session.in_transaction():
|
||||
await db_session.rollback()
|
||||
logger.info("章节生成事务已回滚(GeneratorExit)")
|
||||
except Exception as e:
|
||||
logger.error(f"GeneratorExit回滚失败: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"流式创作章节失败: {str(e)}")
|
||||
if db_session and not db_committed:
|
||||
try:
|
||||
if db_session.in_transaction():
|
||||
await db_session.rollback()
|
||||
logger.info("章节生成事务已回滚(异常)")
|
||||
except Exception as rollback_error:
|
||||
logger.error(f"回滚失败: {str(rollback_error)}")
|
||||
yield f"data: {json.dumps({'type': 'error', 'error': str(e)}, ensure_ascii=False)}\n\n"
|
||||
finally:
|
||||
# 确保数据库会话被正确关闭
|
||||
if db_session:
|
||||
try:
|
||||
# 最后检查:确保没有未提交的事务
|
||||
if not db_committed and db_session.in_transaction():
|
||||
await db_session.rollback()
|
||||
logger.warning("在finally中发现未提交事务,已回滚")
|
||||
|
||||
await db_session.close()
|
||||
logger.info("数据库会话已关闭")
|
||||
except Exception as close_error:
|
||||
logger.error(f"关闭数据库会话失败: {str(close_error)}")
|
||||
# 强制关闭
|
||||
try:
|
||||
await db_session.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,491 @@
|
||||
"""角色管理API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
import json
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.character import Character
|
||||
from app.models.project import Project
|
||||
from app.models.generation_history import GenerationHistory
|
||||
from app.models.relationship import CharacterRelationship, Organization, OrganizationMember, RelationshipType
|
||||
from app.schemas.character import (
|
||||
CharacterUpdate,
|
||||
CharacterResponse,
|
||||
CharacterListResponse,
|
||||
CharacterGenerateRequest
|
||||
)
|
||||
from app.services.ai_service import ai_service
|
||||
from app.services.prompt_service import prompt_service
|
||||
from app.logger import get_logger
|
||||
|
||||
router = APIRouter(prefix="/characters", tags=["角色管理"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@router.get("", response_model=CharacterListResponse, summary="获取角色列表")
|
||||
async def get_characters(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取指定项目的所有角色(query参数版本)"""
|
||||
# 获取总数
|
||||
count_result = await db.execute(
|
||||
select(func.count(Character.id)).where(Character.project_id == project_id)
|
||||
)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# 获取角色列表
|
||||
result = await db.execute(
|
||||
select(Character)
|
||||
.where(Character.project_id == project_id)
|
||||
.order_by(Character.created_at.desc())
|
||||
)
|
||||
characters = result.scalars().all()
|
||||
|
||||
return CharacterListResponse(total=total, items=characters)
|
||||
|
||||
|
||||
@router.get("/project/{project_id}", response_model=CharacterListResponse, summary="获取项目的所有角色")
|
||||
async def get_project_characters(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取指定项目的所有角色(路径参数版本)"""
|
||||
# 获取总数
|
||||
count_result = await db.execute(
|
||||
select(func.count(Character.id)).where(Character.project_id == project_id)
|
||||
)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# 获取角色列表
|
||||
result = await db.execute(
|
||||
select(Character)
|
||||
.where(Character.project_id == project_id)
|
||||
.order_by(Character.created_at.desc())
|
||||
)
|
||||
characters = result.scalars().all()
|
||||
|
||||
return CharacterListResponse(total=total, items=characters)
|
||||
|
||||
|
||||
@router.get("/{character_id}", response_model=CharacterResponse, summary="获取角色详情")
|
||||
async def get_character(
|
||||
character_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""根据ID获取角色详情"""
|
||||
result = await db.execute(
|
||||
select(Character).where(Character.id == character_id)
|
||||
)
|
||||
character = result.scalar_one_or_none()
|
||||
|
||||
if not character:
|
||||
raise HTTPException(status_code=404, detail="角色不存在")
|
||||
|
||||
return character
|
||||
|
||||
|
||||
@router.put("/{character_id}", response_model=CharacterResponse, summary="更新角色")
|
||||
async def update_character(
|
||||
character_id: str,
|
||||
character_update: CharacterUpdate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""更新角色信息"""
|
||||
result = await db.execute(
|
||||
select(Character).where(Character.id == character_id)
|
||||
)
|
||||
character = result.scalar_one_or_none()
|
||||
|
||||
if not character:
|
||||
raise HTTPException(status_code=404, detail="角色不存在")
|
||||
|
||||
# 更新字段
|
||||
update_data = character_update.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(character, field, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(character)
|
||||
return character
|
||||
|
||||
|
||||
@router.delete("/{character_id}", summary="删除角色")
|
||||
async def delete_character(
|
||||
character_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""删除角色"""
|
||||
result = await db.execute(
|
||||
select(Character).where(Character.id == character_id)
|
||||
)
|
||||
character = result.scalar_one_or_none()
|
||||
|
||||
if not character:
|
||||
raise HTTPException(status_code=404, detail="角色不存在")
|
||||
|
||||
await db.delete(character)
|
||||
await db.commit()
|
||||
|
||||
return {"message": "角色删除成功"}
|
||||
|
||||
|
||||
@router.post("/generate", response_model=CharacterResponse, summary="AI生成角色")
|
||||
async def generate_character(
|
||||
request: CharacterGenerateRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
使用AI生成角色卡
|
||||
|
||||
根据用户输入的信息,结合项目的世界观、主题等背景,
|
||||
AI会生成一个完整、详细的角色设定卡片。
|
||||
|
||||
生成内容包括:姓名、年龄、性别、性格、外貌、背景故事、人际关系等
|
||||
"""
|
||||
# 验证项目是否存在并获取项目信息
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == request.project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
try:
|
||||
# 获取已存在的角色列表,用于关系网络
|
||||
existing_chars_result = await db.execute(
|
||||
select(Character)
|
||||
.where(Character.project_id == request.project_id)
|
||||
.order_by(Character.created_at.desc())
|
||||
)
|
||||
existing_characters = existing_chars_result.scalars().all()
|
||||
|
||||
# 构建现有角色信息摘要(包含组织)
|
||||
existing_chars_info = ""
|
||||
character_list = []
|
||||
organization_list = []
|
||||
|
||||
if existing_characters:
|
||||
for c in existing_characters[:10]: # 最多显示10个
|
||||
if c.is_organization:
|
||||
organization_list.append(f"- {c.name} [{c.organization_type or '组织'}]")
|
||||
else:
|
||||
character_list.append(f"- {c.name}({c.role_type or '未知'})")
|
||||
|
||||
if character_list:
|
||||
existing_chars_info += "\n已有角色:\n" + "\n".join(character_list)
|
||||
if organization_list:
|
||||
existing_chars_info += "\n\n已有组织:\n" + "\n".join(organization_list)
|
||||
|
||||
# 构建项目上下文信息
|
||||
project_context = f"""
|
||||
项目信息:
|
||||
- 书名:{project.title}
|
||||
- 主题:{project.theme or '未设定'}
|
||||
- 类型:{project.genre or '未设定'}
|
||||
- 时间背景:{project.world_time_period or '未设定'}
|
||||
- 地理位置:{project.world_location or '未设定'}
|
||||
- 氛围基调:{project.world_atmosphere or '未设定'}
|
||||
- 世界规则:{project.world_rules or '未设定'}
|
||||
{existing_chars_info}
|
||||
"""
|
||||
|
||||
# 构建用户输入信息
|
||||
user_input = f"""
|
||||
用户要求:
|
||||
- 角色名称:{request.name or '请AI生成'}
|
||||
- 角色定位:{request.role_type or 'supporting'}(protagonist=主角, supporting=配角, antagonist=反派)
|
||||
- 背景设定:{request.background or '无特殊要求'}
|
||||
- 其他要求:{request.requirements or '无'}
|
||||
"""
|
||||
|
||||
# 使用统一的提示词服务
|
||||
prompt = prompt_service.get_single_character_prompt(
|
||||
project_context=project_context,
|
||||
user_input=user_input
|
||||
)
|
||||
|
||||
# 调用AI生成角色
|
||||
logger.info(f"🎯 开始为项目 {request.project_id} 生成角色")
|
||||
logger.info(f" - 角色名:{request.name or 'AI生成'}")
|
||||
logger.info(f" - 角色定位:{request.role_type}")
|
||||
logger.info(f" - 背景设定:{request.background or '无'}")
|
||||
logger.info(f" - AI提供商:{request.provider or 'default'}")
|
||||
logger.info(f" - AI模型:{request.model or 'default'}")
|
||||
logger.info(f" - Prompt长度:{len(prompt)} 字符")
|
||||
|
||||
try:
|
||||
ai_response = await ai_service.generate_text(
|
||||
prompt=prompt,
|
||||
provider=request.provider,
|
||||
model=request.model
|
||||
)
|
||||
logger.info(f"✅ AI响应接收完成,长度:{len(ai_response) if ai_response else 0} 字符")
|
||||
except Exception as ai_error:
|
||||
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"AI服务调用失败:{str(ai_error)}"
|
||||
)
|
||||
|
||||
# 检查AI响应
|
||||
if not ai_response or not ai_response.strip():
|
||||
logger.error("❌ AI返回了空响应")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="AI服务返回空响应。可能原因:1) API配置错误 2) 模型不支持 3) 网络问题。请检查后端日志。"
|
||||
)
|
||||
|
||||
logger.info(f"📝 开始清理AI响应")
|
||||
# 清理AI响应,移除可能的markdown标记
|
||||
cleaned_response = ai_response.strip()
|
||||
original_length = len(cleaned_response)
|
||||
|
||||
if cleaned_response.startswith("```json"):
|
||||
cleaned_response = cleaned_response[7:]
|
||||
logger.info(" - 移除了 ```json 标记")
|
||||
if cleaned_response.startswith("```"):
|
||||
cleaned_response = cleaned_response[3:]
|
||||
logger.info(" - 移除了 ``` 标记")
|
||||
if cleaned_response.endswith("```"):
|
||||
cleaned_response = cleaned_response[:-3]
|
||||
logger.info(" - 移除了末尾 ``` 标记")
|
||||
cleaned_response = cleaned_response.strip()
|
||||
|
||||
logger.info(f" - 清理前长度:{original_length},清理后长度:{len(cleaned_response)}")
|
||||
logger.info(f" - 清理后内容预览(前300字符):{cleaned_response[:300]}")
|
||||
|
||||
# 解析AI响应
|
||||
logger.info(f"🔍 开始解析JSON")
|
||||
try:
|
||||
character_data = json.loads(cleaned_response)
|
||||
logger.info(f"✅ JSON解析成功")
|
||||
logger.info(f" - 解析后的字段:{list(character_data.keys())}")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"❌ JSON解析失败")
|
||||
logger.error(f" - 错误位置:line {e.lineno}, column {e.colno}")
|
||||
logger.error(f" - 错误信息:{str(e)}")
|
||||
logger.error(f" - 完整响应内容(前1000字符):{cleaned_response[:1000]}")
|
||||
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"AI返回的内容无法解析为JSON。错误:{str(e)}。响应内容已记录到日志,请查看后端日志排查。"
|
||||
)
|
||||
|
||||
# 转换traits为JSON字符串
|
||||
traits_json = json.dumps(character_data.get("traits", []), ensure_ascii=False) if character_data.get("traits") else None
|
||||
|
||||
# 判断是否为组织
|
||||
is_organization = character_data.get("is_organization", False)
|
||||
|
||||
# 创建角色
|
||||
character = Character(
|
||||
project_id=request.project_id,
|
||||
name=character_data.get("name", request.name or "未命名角色"),
|
||||
age=str(character_data.get("age", "")),
|
||||
gender=character_data.get("gender"),
|
||||
is_organization=is_organization,
|
||||
role_type=request.role_type or "supporting",
|
||||
personality=character_data.get("personality", ""),
|
||||
background=character_data.get("background", ""),
|
||||
appearance=character_data.get("appearance", ""),
|
||||
relationships=character_data.get("relationships_text", character_data.get("relationships", "")), # 优先使用文本描述
|
||||
organization_type=character_data.get("organization_type") if is_organization else None,
|
||||
organization_purpose=character_data.get("organization_purpose") if is_organization else None,
|
||||
organization_members=json.dumps(character_data.get("organization_members", []), ensure_ascii=False) if is_organization else None,
|
||||
traits=traits_json
|
||||
)
|
||||
db.add(character)
|
||||
await db.flush() # 获取character.id
|
||||
|
||||
logger.info(f"✅ 角色创建成功:{character.name} (ID: {character.id}, 是否组织: {is_organization})")
|
||||
|
||||
# 如果是组织,自动创建Organization详情记录
|
||||
if is_organization:
|
||||
org_check = await db.execute(
|
||||
select(Organization).where(Organization.character_id == character.id)
|
||||
)
|
||||
existing_org = org_check.scalar_one_or_none()
|
||||
|
||||
if not existing_org:
|
||||
organization = Organization(
|
||||
character_id=character.id,
|
||||
project_id=request.project_id,
|
||||
member_count=0,
|
||||
power_level=character_data.get("power_level", 50),
|
||||
location=character_data.get("location"),
|
||||
motto=character_data.get("motto")
|
||||
)
|
||||
db.add(organization)
|
||||
await db.flush()
|
||||
logger.info(f"✅ 自动创建组织详情:{character.name} (Org ID: {organization.id})")
|
||||
else:
|
||||
logger.info(f"ℹ️ 组织详情已存在:{character.name}")
|
||||
|
||||
# 处理结构化关系数据(仅针对非组织角色)
|
||||
if not is_organization:
|
||||
relationships_data = character_data.get("relationships", [])
|
||||
if relationships_data and isinstance(relationships_data, list):
|
||||
logger.info(f"📊 开始处理 {len(relationships_data)} 条关系数据")
|
||||
created_rels = 0
|
||||
|
||||
for rel in relationships_data:
|
||||
try:
|
||||
target_name = rel.get("target_character_name")
|
||||
if not target_name:
|
||||
logger.debug(f" ⚠️ 关系缺少target_character_name,跳过")
|
||||
continue
|
||||
|
||||
target_result = await db.execute(
|
||||
select(Character).where(
|
||||
Character.project_id == request.project_id,
|
||||
Character.name == target_name
|
||||
)
|
||||
)
|
||||
target_char = target_result.scalar_one_or_none()
|
||||
|
||||
if target_char:
|
||||
# 检查是否已存在相同关系
|
||||
existing_rel = await db.execute(
|
||||
select(CharacterRelationship).where(
|
||||
CharacterRelationship.project_id == request.project_id,
|
||||
CharacterRelationship.character_from_id == character.id,
|
||||
CharacterRelationship.character_to_id == target_char.id
|
||||
)
|
||||
)
|
||||
if existing_rel.scalar_one_or_none():
|
||||
logger.debug(f" ℹ️ 关系已存在:{character.name} -> {target_name}")
|
||||
continue
|
||||
|
||||
relationship = CharacterRelationship(
|
||||
project_id=request.project_id,
|
||||
character_from_id=character.id,
|
||||
character_to_id=target_char.id,
|
||||
relationship_name=rel.get("relationship_type", "未知关系"),
|
||||
intimacy_level=rel.get("intimacy_level", 50),
|
||||
description=rel.get("description", ""),
|
||||
started_at=rel.get("started_at"),
|
||||
source="ai"
|
||||
)
|
||||
|
||||
# 匹配预定义关系类型
|
||||
rel_type_result = await db.execute(
|
||||
select(RelationshipType).where(
|
||||
RelationshipType.name == rel.get("relationship_type")
|
||||
)
|
||||
)
|
||||
rel_type = rel_type_result.scalar_one_or_none()
|
||||
if rel_type:
|
||||
relationship.relationship_type_id = rel_type.id
|
||||
|
||||
db.add(relationship)
|
||||
created_rels += 1
|
||||
logger.info(f" ✅ 创建关系:{character.name} -> {target_name} ({rel.get('relationship_type')})")
|
||||
else:
|
||||
logger.warning(f" ⚠️ 目标角色不存在:{target_name}")
|
||||
|
||||
except Exception as rel_error:
|
||||
logger.warning(f" ❌ 创建关系失败:{str(rel_error)}")
|
||||
continue
|
||||
|
||||
logger.info(f"✅ 成功创建 {created_rels} 条关系记录")
|
||||
|
||||
# 处理组织成员关系(仅针对非组织角色)
|
||||
if not is_organization:
|
||||
org_memberships = character_data.get("organization_memberships", [])
|
||||
if org_memberships and isinstance(org_memberships, list):
|
||||
logger.info(f"🏢 开始处理 {len(org_memberships)} 条组织成员关系")
|
||||
created_members = 0
|
||||
|
||||
for membership in org_memberships:
|
||||
try:
|
||||
org_name = membership.get("organization_name")
|
||||
if not org_name:
|
||||
logger.debug(f" ⚠️ 组织成员关系缺少organization_name,跳过")
|
||||
continue
|
||||
|
||||
org_char_result = await db.execute(
|
||||
select(Character).where(
|
||||
Character.project_id == request.project_id,
|
||||
Character.name == org_name,
|
||||
Character.is_organization == True
|
||||
)
|
||||
)
|
||||
org_char = org_char_result.scalar_one_or_none()
|
||||
|
||||
if org_char:
|
||||
# 获取或创建Organization记录
|
||||
org_result = await db.execute(
|
||||
select(Organization).where(Organization.character_id == org_char.id)
|
||||
)
|
||||
org = org_result.scalar_one_or_none()
|
||||
|
||||
if not org:
|
||||
# 如果组织Character存在但Organization不存在,自动创建
|
||||
org = Organization(
|
||||
character_id=org_char.id,
|
||||
project_id=request.project_id,
|
||||
member_count=0
|
||||
)
|
||||
db.add(org)
|
||||
await db.flush()
|
||||
logger.info(f" ℹ️ 自动创建缺失的组织详情:{org_name}")
|
||||
|
||||
# 检查是否已存在成员关系
|
||||
existing_member = await db.execute(
|
||||
select(OrganizationMember).where(
|
||||
OrganizationMember.organization_id == org.id,
|
||||
OrganizationMember.character_id == character.id
|
||||
)
|
||||
)
|
||||
if existing_member.scalar_one_or_none():
|
||||
logger.debug(f" ℹ️ 成员关系已存在:{character.name} -> {org_name}")
|
||||
continue
|
||||
|
||||
# 创建成员关系
|
||||
member = OrganizationMember(
|
||||
organization_id=org.id,
|
||||
character_id=character.id,
|
||||
position=membership.get("position", "成员"),
|
||||
rank=membership.get("rank", 0),
|
||||
loyalty=membership.get("loyalty", 50),
|
||||
joined_at=membership.get("joined_at"),
|
||||
status=membership.get("status", "active"),
|
||||
source="ai"
|
||||
)
|
||||
db.add(member)
|
||||
|
||||
# 更新组织成员计数
|
||||
org.member_count += 1
|
||||
|
||||
created_members += 1
|
||||
logger.info(f" ✅ 添加成员:{character.name} -> {org_name} ({membership.get('position')})")
|
||||
else:
|
||||
logger.warning(f" ⚠️ 组织不存在:{org_name}")
|
||||
|
||||
except Exception as org_error:
|
||||
logger.warning(f" ❌ 添加组织成员失败:{str(org_error)}")
|
||||
continue
|
||||
|
||||
logger.info(f"✅ 成功创建 {created_members} 条组织成员记录")
|
||||
|
||||
# 记录生成历史
|
||||
history = GenerationHistory(
|
||||
project_id=request.project_id,
|
||||
prompt=prompt,
|
||||
generated_content=ai_response,
|
||||
model=request.model or "default"
|
||||
)
|
||||
db.add(history)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(character)
|
||||
|
||||
logger.info(f"🎉 成功为项目 {request.project_id} 生成角色: {character.name}")
|
||||
|
||||
return character
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成角色失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"生成角色失败: {str(e)}")
|
||||
@@ -0,0 +1,341 @@
|
||||
"""组织管理API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_
|
||||
from typing import List
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.relationship import Organization, OrganizationMember
|
||||
from app.models.character import Character
|
||||
from app.schemas.relationship import (
|
||||
OrganizationCreate,
|
||||
OrganizationUpdate,
|
||||
OrganizationResponse,
|
||||
OrganizationDetailResponse,
|
||||
OrganizationMemberCreate,
|
||||
OrganizationMemberUpdate,
|
||||
OrganizationMemberResponse,
|
||||
OrganizationMemberDetailResponse
|
||||
)
|
||||
from app.logger import get_logger
|
||||
|
||||
router = APIRouter(prefix="/organizations", tags=["组织管理"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@router.get("/project/{project_id}", response_model=List[OrganizationDetailResponse], summary="获取项目的所有组织")
|
||||
async def get_project_organizations(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取项目中的所有组织及其详情
|
||||
|
||||
返回组织的基本信息和统计数据
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(Organization).where(Organization.project_id == project_id)
|
||||
)
|
||||
organizations = result.scalars().all()
|
||||
|
||||
# 获取每个组织的角色信息
|
||||
org_list = []
|
||||
for org in organizations:
|
||||
char_result = await db.execute(
|
||||
select(Character).where(Character.id == org.character_id)
|
||||
)
|
||||
char = char_result.scalar_one_or_none()
|
||||
|
||||
if char:
|
||||
org_list.append(OrganizationDetailResponse(
|
||||
id=org.id,
|
||||
character_id=org.character_id,
|
||||
name=char.name,
|
||||
type=char.organization_type,
|
||||
purpose=char.organization_purpose,
|
||||
member_count=org.member_count,
|
||||
power_level=org.power_level,
|
||||
location=org.location,
|
||||
motto=org.motto,
|
||||
color=org.color
|
||||
))
|
||||
|
||||
logger.info(f"获取项目 {project_id} 的组织列表,共 {len(org_list)} 个")
|
||||
return org_list
|
||||
|
||||
|
||||
@router.get("/{org_id}", response_model=OrganizationResponse, summary="获取组织详情")
|
||||
async def get_organization(
|
||||
org_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取组织的详细信息"""
|
||||
result = await db.execute(
|
||||
select(Organization).where(Organization.id == org_id)
|
||||
)
|
||||
org = result.scalar_one_or_none()
|
||||
|
||||
if not org:
|
||||
raise HTTPException(status_code=404, detail="组织不存在")
|
||||
|
||||
return org
|
||||
|
||||
|
||||
@router.post("/", response_model=OrganizationResponse, summary="创建组织")
|
||||
async def create_organization(
|
||||
organization: OrganizationCreate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
创建新组织
|
||||
|
||||
- 需要关联到一个已存在的角色记录(is_organization=True)
|
||||
- 可以设置父组织、势力等级等属性
|
||||
"""
|
||||
# 验证角色是否存在且是组织
|
||||
char_result = await db.execute(
|
||||
select(Character).where(Character.id == organization.character_id)
|
||||
)
|
||||
char = char_result.scalar_one_or_none()
|
||||
|
||||
if not char:
|
||||
raise HTTPException(status_code=404, detail="关联的角色不存在")
|
||||
if not char.is_organization:
|
||||
raise HTTPException(status_code=400, detail="关联的角色不是组织类型")
|
||||
|
||||
# 检查是否已存在
|
||||
existing = await db.execute(
|
||||
select(Organization).where(Organization.character_id == organization.character_id)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(status_code=400, detail="该角色已有组织详情记录")
|
||||
|
||||
# 创建组织
|
||||
db_org = Organization(**organization.model_dump())
|
||||
db.add(db_org)
|
||||
await db.commit()
|
||||
await db.refresh(db_org)
|
||||
|
||||
logger.info(f"创建组织成功:{db_org.id} - {char.name}")
|
||||
return db_org
|
||||
|
||||
|
||||
@router.put("/{org_id}", response_model=OrganizationResponse, summary="更新组织")
|
||||
async def update_organization(
|
||||
org_id: str,
|
||||
organization: OrganizationUpdate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""更新组织的属性"""
|
||||
result = await db.execute(
|
||||
select(Organization).where(Organization.id == org_id)
|
||||
)
|
||||
db_org = result.scalar_one_or_none()
|
||||
|
||||
if not db_org:
|
||||
raise HTTPException(status_code=404, detail="组织不存在")
|
||||
|
||||
# 更新字段
|
||||
update_data = organization.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(db_org, field, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(db_org)
|
||||
|
||||
logger.info(f"更新组织成功:{org_id}")
|
||||
return db_org
|
||||
|
||||
|
||||
@router.delete("/{org_id}", summary="删除组织")
|
||||
async def delete_organization(
|
||||
org_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""删除组织(会级联删除所有成员关系)"""
|
||||
result = await db.execute(
|
||||
select(Organization).where(Organization.id == org_id)
|
||||
)
|
||||
db_org = result.scalar_one_or_none()
|
||||
|
||||
if not db_org:
|
||||
raise HTTPException(status_code=404, detail="组织不存在")
|
||||
|
||||
await db.delete(db_org)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"删除组织成功:{org_id}")
|
||||
return {"message": "组织删除成功", "id": org_id}
|
||||
|
||||
|
||||
# ============ 组织成员管理 ============
|
||||
|
||||
@router.get("/{org_id}/members", response_model=List[OrganizationMemberDetailResponse], summary="获取组织成员")
|
||||
async def get_organization_members(
|
||||
org_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取组织的所有成员
|
||||
|
||||
按职位等级(rank)降序排列
|
||||
"""
|
||||
# 验证组织存在
|
||||
org_result = await db.execute(
|
||||
select(Organization).where(Organization.id == org_id)
|
||||
)
|
||||
if not org_result.scalar_one_or_none():
|
||||
raise HTTPException(status_code=404, detail="组织不存在")
|
||||
|
||||
# 获取成员列表
|
||||
result = await db.execute(
|
||||
select(OrganizationMember)
|
||||
.where(OrganizationMember.organization_id == org_id)
|
||||
.order_by(OrganizationMember.rank.desc(), OrganizationMember.created_at)
|
||||
)
|
||||
members = result.scalars().all()
|
||||
|
||||
# 获取成员角色信息
|
||||
member_list = []
|
||||
for member in members:
|
||||
char_result = await db.execute(
|
||||
select(Character).where(Character.id == member.character_id)
|
||||
)
|
||||
char = char_result.scalar_one_or_none()
|
||||
|
||||
if char:
|
||||
member_list.append(OrganizationMemberDetailResponse(
|
||||
id=member.id,
|
||||
character_id=member.character_id,
|
||||
character_name=char.name,
|
||||
position=member.position,
|
||||
rank=member.rank,
|
||||
loyalty=member.loyalty,
|
||||
contribution=member.contribution,
|
||||
status=member.status,
|
||||
joined_at=member.joined_at,
|
||||
left_at=member.left_at,
|
||||
notes=member.notes
|
||||
))
|
||||
|
||||
logger.info(f"获取组织 {org_id} 的成员列表,共 {len(member_list)} 人")
|
||||
return member_list
|
||||
|
||||
|
||||
@router.post("/{org_id}/members", response_model=OrganizationMemberResponse, summary="添加组织成员")
|
||||
async def add_organization_member(
|
||||
org_id: str,
|
||||
member: OrganizationMemberCreate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
添加角色到组织
|
||||
|
||||
- 一个角色在同一组织中只能有一个职位
|
||||
- 会自动更新组织的成员计数
|
||||
"""
|
||||
# 验证组织存在
|
||||
org_result = await db.execute(
|
||||
select(Organization).where(Organization.id == org_id)
|
||||
)
|
||||
org = org_result.scalar_one_or_none()
|
||||
if not org:
|
||||
raise HTTPException(status_code=404, detail="组织不存在")
|
||||
|
||||
# 验证角色存在
|
||||
char_result = await db.execute(
|
||||
select(Character).where(Character.id == member.character_id)
|
||||
)
|
||||
char = char_result.scalar_one_or_none()
|
||||
if not char:
|
||||
raise HTTPException(status_code=404, detail="角色不存在")
|
||||
if char.is_organization:
|
||||
raise HTTPException(status_code=400, detail="不能将组织添加为成员")
|
||||
|
||||
# 检查是否已存在
|
||||
existing = await db.execute(
|
||||
select(OrganizationMember).where(
|
||||
and_(
|
||||
OrganizationMember.organization_id == org_id,
|
||||
OrganizationMember.character_id == member.character_id
|
||||
)
|
||||
)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(status_code=400, detail="该角色已在组织中")
|
||||
|
||||
# 创建成员关系
|
||||
db_member = OrganizationMember(
|
||||
organization_id=org_id,
|
||||
**member.model_dump(),
|
||||
source="manual"
|
||||
)
|
||||
db.add(db_member)
|
||||
|
||||
# 更新组织成员计数
|
||||
org.member_count += 1
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(db_member)
|
||||
|
||||
logger.info(f"添加成员成功:{char.name} 加入组织 {org_id}")
|
||||
return db_member
|
||||
|
||||
|
||||
@router.put("/members/{member_id}", response_model=OrganizationMemberResponse, summary="更新成员信息")
|
||||
async def update_organization_member(
|
||||
member_id: str,
|
||||
member: OrganizationMemberUpdate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""更新组织成员的职位、忠诚度等信息"""
|
||||
result = await db.execute(
|
||||
select(OrganizationMember).where(OrganizationMember.id == member_id)
|
||||
)
|
||||
db_member = result.scalar_one_or_none()
|
||||
|
||||
if not db_member:
|
||||
raise HTTPException(status_code=404, detail="成员记录不存在")
|
||||
|
||||
# 更新字段
|
||||
update_data = member.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(db_member, field, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(db_member)
|
||||
|
||||
logger.info(f"更新成员信息成功:{member_id}")
|
||||
return db_member
|
||||
|
||||
|
||||
@router.delete("/members/{member_id}", summary="移除组织成员")
|
||||
async def remove_organization_member(
|
||||
member_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
从组织中移除成员
|
||||
|
||||
会自动更新组织的成员计数
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(OrganizationMember).where(OrganizationMember.id == member_id)
|
||||
)
|
||||
db_member = result.scalar_one_or_none()
|
||||
|
||||
if not db_member:
|
||||
raise HTTPException(status_code=404, detail="成员记录不存在")
|
||||
|
||||
# 更新组织成员计数
|
||||
org_result = await db.execute(
|
||||
select(Organization).where(Organization.id == db_member.organization_id)
|
||||
)
|
||||
org = org_result.scalar_one()
|
||||
org.member_count = max(0, org.member_count - 1)
|
||||
|
||||
await db.delete(db_member)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"移除成员成功:{member_id}")
|
||||
return {"message": "成员移除成功", "id": member_id}
|
||||
@@ -0,0 +1,657 @@
|
||||
"""大纲管理API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, delete
|
||||
from typing import List
|
||||
import json
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.outline import Outline
|
||||
from app.models.project import Project
|
||||
from app.models.chapter import Chapter
|
||||
from app.models.character import Character
|
||||
from app.models.generation_history import GenerationHistory
|
||||
from app.schemas.outline import (
|
||||
OutlineCreate,
|
||||
OutlineUpdate,
|
||||
OutlineResponse,
|
||||
OutlineListResponse,
|
||||
OutlineGenerateRequest,
|
||||
OutlineReorderRequest
|
||||
)
|
||||
from app.services.ai_service import ai_service
|
||||
from app.services.prompt_service import prompt_service
|
||||
from app.logger import get_logger
|
||||
|
||||
router = APIRouter(prefix="/outlines", tags=["大纲管理"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@router.post("", response_model=OutlineResponse, summary="创建大纲")
|
||||
async def create_outline(
|
||||
outline: OutlineCreate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""创建新的章节大纲,同时创建对应的章节记录"""
|
||||
# 验证项目是否存在
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == outline.project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
# 创建大纲
|
||||
db_outline = Outline(**outline.model_dump())
|
||||
db.add(db_outline)
|
||||
|
||||
# 同步创建对应的章节记录
|
||||
chapter = Chapter(
|
||||
project_id=outline.project_id,
|
||||
chapter_number=outline.order_index,
|
||||
title=outline.title,
|
||||
summary=outline.content[:500] if len(outline.content) > 500 else outline.content,
|
||||
status="draft"
|
||||
)
|
||||
db.add(chapter)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(db_outline)
|
||||
return db_outline
|
||||
|
||||
|
||||
@router.get("", response_model=OutlineListResponse, summary="获取大纲列表")
|
||||
async def get_outlines(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取指定项目的所有大纲"""
|
||||
# 获取总数
|
||||
count_result = await db.execute(
|
||||
select(func.count(Outline.id)).where(Outline.project_id == project_id)
|
||||
)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# 获取大纲列表
|
||||
result = await db.execute(
|
||||
select(Outline)
|
||||
.where(Outline.project_id == project_id)
|
||||
.order_by(Outline.order_index)
|
||||
)
|
||||
outlines = result.scalars().all()
|
||||
|
||||
return OutlineListResponse(total=total, items=outlines)
|
||||
|
||||
|
||||
@router.get("/project/{project_id}", response_model=OutlineListResponse, summary="获取项目的所有大纲")
|
||||
async def get_project_outlines(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取指定项目的所有大纲(路径参数版本)"""
|
||||
# 获取总数
|
||||
count_result = await db.execute(
|
||||
select(func.count(Outline.id)).where(Outline.project_id == project_id)
|
||||
)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# 获取大纲列表
|
||||
result = await db.execute(
|
||||
select(Outline)
|
||||
.where(Outline.project_id == project_id)
|
||||
.order_by(Outline.order_index)
|
||||
)
|
||||
outlines = result.scalars().all()
|
||||
|
||||
return OutlineListResponse(total=total, items=outlines)
|
||||
|
||||
|
||||
@router.get("/{outline_id}", response_model=OutlineResponse, summary="获取大纲详情")
|
||||
async def get_outline(
|
||||
outline_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""根据ID获取大纲详情"""
|
||||
result = await db.execute(
|
||||
select(Outline).where(Outline.id == outline_id)
|
||||
)
|
||||
outline = result.scalar_one_or_none()
|
||||
|
||||
if not outline:
|
||||
raise HTTPException(status_code=404, detail="大纲不存在")
|
||||
|
||||
return outline
|
||||
|
||||
|
||||
@router.put("/{outline_id}", response_model=OutlineResponse, summary="更新大纲")
|
||||
async def update_outline(
|
||||
outline_id: str,
|
||||
outline_update: OutlineUpdate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""更新大纲信息,同步更新对应章节和structure字段"""
|
||||
result = await db.execute(
|
||||
select(Outline).where(Outline.id == outline_id)
|
||||
)
|
||||
outline = result.scalar_one_or_none()
|
||||
|
||||
if not outline:
|
||||
raise HTTPException(status_code=404, detail="大纲不存在")
|
||||
|
||||
# 更新字段
|
||||
update_data = outline_update.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(outline, field, value)
|
||||
|
||||
# 如果修改了content或title,同步更新structure字段
|
||||
if 'content' in update_data or 'title' in update_data:
|
||||
try:
|
||||
# 尝试解析现有的structure
|
||||
if outline.structure:
|
||||
structure_data = json.loads(outline.structure)
|
||||
else:
|
||||
structure_data = {}
|
||||
|
||||
# 更新structure中的对应字段
|
||||
if 'title' in update_data:
|
||||
structure_data['title'] = outline.title
|
||||
if 'content' in update_data:
|
||||
structure_data['summary'] = outline.content
|
||||
structure_data['content'] = outline.content
|
||||
|
||||
# 保存更新后的structure
|
||||
outline.structure = json.dumps(structure_data, ensure_ascii=False)
|
||||
logger.info(f"同步更新大纲 {outline_id} 的structure字段")
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"大纲 {outline_id} 的structure字段格式错误,跳过更新")
|
||||
|
||||
# 同步更新对应的章节标题和摘要
|
||||
if 'title' in update_data or 'content' in update_data:
|
||||
chapter_result = await db.execute(
|
||||
select(Chapter).where(
|
||||
Chapter.project_id == outline.project_id,
|
||||
Chapter.chapter_number == outline.order_index
|
||||
)
|
||||
)
|
||||
chapter = chapter_result.scalar_one_or_none()
|
||||
|
||||
if chapter:
|
||||
if 'title' in update_data:
|
||||
chapter.title = outline.title
|
||||
if 'content' in update_data:
|
||||
# 更新章节摘要(取content前500字符)
|
||||
chapter.summary = outline.content[:500] if len(outline.content) > 500 else outline.content
|
||||
logger.info(f"同步更新章节 {chapter.id} 的标题和摘要")
|
||||
else:
|
||||
logger.warning(f"未找到对应的章节记录 (order_index={outline.order_index})")
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(outline)
|
||||
return outline
|
||||
|
||||
|
||||
@router.delete("/{outline_id}", summary="删除大纲")
|
||||
async def delete_outline(
|
||||
outline_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""删除大纲,同步删除章节,并重新排序后续项"""
|
||||
result = await db.execute(
|
||||
select(Outline).where(Outline.id == outline_id)
|
||||
)
|
||||
outline = result.scalar_one_or_none()
|
||||
|
||||
if not outline:
|
||||
raise HTTPException(status_code=404, detail="大纲不存在")
|
||||
|
||||
project_id = outline.project_id
|
||||
deleted_order = outline.order_index
|
||||
|
||||
# 删除对应的章节
|
||||
await db.execute(
|
||||
delete(Chapter).where(
|
||||
Chapter.project_id == project_id,
|
||||
Chapter.chapter_number == deleted_order
|
||||
)
|
||||
)
|
||||
|
||||
# 删除大纲
|
||||
await db.delete(outline)
|
||||
|
||||
# 重新排序后续的大纲和章节(序号-1)
|
||||
result = await db.execute(
|
||||
select(Outline).where(
|
||||
Outline.project_id == project_id,
|
||||
Outline.order_index > deleted_order
|
||||
)
|
||||
)
|
||||
subsequent_outlines = result.scalars().all()
|
||||
|
||||
for o in subsequent_outlines:
|
||||
old_order = o.order_index
|
||||
o.order_index -= 1
|
||||
|
||||
# 同步更新对应的章节
|
||||
chapter_result = await db.execute(
|
||||
select(Chapter).where(
|
||||
Chapter.project_id == project_id,
|
||||
Chapter.chapter_number == old_order
|
||||
)
|
||||
)
|
||||
chapter = chapter_result.scalar_one_or_none()
|
||||
if chapter:
|
||||
chapter.chapter_number = old_order - 1
|
||||
|
||||
await db.commit()
|
||||
|
||||
return {"message": "大纲删除成功"}
|
||||
|
||||
|
||||
@router.post("/reorder", summary="批量重排序大纲")
|
||||
async def reorder_outlines(
|
||||
request: OutlineReorderRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
批量调整大纲顺序,同步更新章节序号
|
||||
|
||||
策略:先收集所有变更,最后一次性提交,避免临时冲突
|
||||
"""
|
||||
try:
|
||||
# 第一步:收集所有大纲和对应的章节
|
||||
outline_chapter_map = {} # {outline_id: (outline, chapter, old_order, new_order)}
|
||||
|
||||
for item in request.orders:
|
||||
outline_id = item.id
|
||||
new_order = item.order_index
|
||||
|
||||
# 获取大纲
|
||||
result = await db.execute(
|
||||
select(Outline).where(Outline.id == outline_id)
|
||||
)
|
||||
outline = result.scalar_one_or_none()
|
||||
|
||||
if not outline:
|
||||
logger.warning(f"大纲 {outline_id} 不存在,跳过")
|
||||
continue
|
||||
|
||||
old_order = outline.order_index
|
||||
|
||||
# 获取对应的章节(通过旧的chapter_number匹配)
|
||||
chapter_result = await db.execute(
|
||||
select(Chapter).where(
|
||||
Chapter.project_id == outline.project_id,
|
||||
Chapter.chapter_number == old_order
|
||||
)
|
||||
)
|
||||
chapter = chapter_result.first()
|
||||
chapter_obj = chapter[0] if chapter else None
|
||||
|
||||
outline_chapter_map[outline_id] = (outline, chapter_obj, old_order, new_order)
|
||||
|
||||
# 第二步:批量更新所有大纲和章节
|
||||
updated_outlines = 0
|
||||
updated_chapters = 0
|
||||
|
||||
for outline_id, (outline, chapter, old_order, new_order) in outline_chapter_map.items():
|
||||
# 更新大纲
|
||||
outline.order_index = new_order
|
||||
updated_outlines += 1
|
||||
|
||||
# 更新章节
|
||||
if chapter:
|
||||
chapter.chapter_number = new_order
|
||||
chapter.title = outline.title # 同步更新标题
|
||||
updated_chapters += 1
|
||||
else:
|
||||
logger.warning(f"章节 {old_order} 不存在,跳过")
|
||||
|
||||
# 第三步:一次性提交所有更改
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"重排序成功:更新了 {updated_outlines} 个大纲,{updated_chapters} 个章节")
|
||||
|
||||
return {
|
||||
"message": "重排序成功",
|
||||
"updated_outlines": updated_outlines,
|
||||
"updated_chapters": updated_chapters
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"重排序失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"重排序失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/generate", response_model=OutlineListResponse, summary="AI生成/续写大纲")
|
||||
async def generate_outline(
|
||||
request: OutlineGenerateRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
使用AI生成或续写小说大纲 - 智能模式
|
||||
|
||||
支持三种模式:
|
||||
- auto: 自动判断(无大纲→新建,有大纲→续写)
|
||||
- new: 强制全新生成
|
||||
- continue: 强制续写模式
|
||||
"""
|
||||
# 验证项目是否存在
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == request.project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
try:
|
||||
# 获取现有大纲(强制从数据库获取最新数据,包括用户手动修改的内容)
|
||||
existing_result = await db.execute(
|
||||
select(Outline)
|
||||
.where(Outline.project_id == request.project_id)
|
||||
.order_by(Outline.order_index)
|
||||
.execution_options(populate_existing=True)
|
||||
)
|
||||
existing_outlines = existing_result.scalars().all()
|
||||
|
||||
# 判断实际执行模式
|
||||
actual_mode = request.mode
|
||||
if actual_mode == "auto":
|
||||
actual_mode = "continue" if existing_outlines else "new"
|
||||
logger.info(f"自动判断模式:{'续写' if existing_outlines else '新建'}")
|
||||
|
||||
# 模式:全新生成
|
||||
if actual_mode == "new":
|
||||
return await _generate_new_outline(
|
||||
request, project, db
|
||||
)
|
||||
|
||||
# 模式:续写
|
||||
elif actual_mode == "continue":
|
||||
if not existing_outlines:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="续写模式需要已有大纲,当前项目没有大纲"
|
||||
)
|
||||
|
||||
return await _continue_outline(
|
||||
request, project, existing_outlines, db
|
||||
)
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"不支持的模式: {request.mode}"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"生成大纲失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"生成大纲失败: {str(e)}")
|
||||
|
||||
|
||||
async def _generate_new_outline(
|
||||
request: OutlineGenerateRequest,
|
||||
project: Project,
|
||||
db: AsyncSession
|
||||
) -> OutlineListResponse:
|
||||
"""全新生成大纲"""
|
||||
logger.info(f"全新生成大纲 - 项目: {project.id}, keep_existing: {request.keep_existing}")
|
||||
|
||||
# 获取角色信息
|
||||
characters_result = await db.execute(
|
||||
select(Character).where(Character.project_id == project.id)
|
||||
)
|
||||
characters = characters_result.scalars().all()
|
||||
characters_info = "\n".join([
|
||||
f"- {char.name} ({'组织' if char.is_organization else '角色'}, {char.role_type}): "
|
||||
f"{char.personality[:100] if char.personality else '暂无描述'}"
|
||||
for char in characters
|
||||
])
|
||||
|
||||
# 使用完整提示词
|
||||
prompt = prompt_service.get_complete_outline_prompt(
|
||||
title=project.title,
|
||||
theme=request.theme or project.theme or "未设定",
|
||||
genre=request.genre or project.genre or "通用",
|
||||
chapter_count=request.chapter_count,
|
||||
narrative_perspective=request.narrative_perspective,
|
||||
target_words=request.target_words,
|
||||
time_period=project.world_time_period or "未设定",
|
||||
location=project.world_location or "未设定",
|
||||
atmosphere=project.world_atmosphere or "未设定",
|
||||
rules=project.world_rules or "未设定",
|
||||
characters_info=characters_info or "暂无角色信息",
|
||||
requirements=request.requirements or ""
|
||||
)
|
||||
|
||||
# 调用AI
|
||||
ai_response = await ai_service.generate_text(
|
||||
prompt=prompt,
|
||||
provider=request.provider,
|
||||
model=request.model
|
||||
)
|
||||
|
||||
# 解析响应
|
||||
outline_data = _parse_ai_response(ai_response)
|
||||
|
||||
# 全新生成模式:必须删除旧大纲和章节
|
||||
# 注意:这是"new"模式的核心逻辑,应该始终删除旧数据
|
||||
logger.info(f"删除项目 {project.id} 的旧大纲和章节")
|
||||
await db.execute(
|
||||
delete(Outline).where(Outline.project_id == project.id)
|
||||
)
|
||||
await db.execute(
|
||||
delete(Chapter).where(Chapter.project_id == project.id)
|
||||
)
|
||||
|
||||
# 保存新大纲
|
||||
outlines = await _save_outlines(
|
||||
project.id, outline_data, db, start_index=1
|
||||
)
|
||||
|
||||
# 记录历史
|
||||
history = GenerationHistory(
|
||||
project_id=project.id,
|
||||
prompt=prompt,
|
||||
generated_content=ai_response,
|
||||
model=request.model or "default"
|
||||
)
|
||||
db.add(history)
|
||||
|
||||
await db.commit()
|
||||
|
||||
for outline in outlines:
|
||||
await db.refresh(outline)
|
||||
|
||||
logger.info(f"全新生成完成 - {len(outlines)} 章")
|
||||
return OutlineListResponse(total=len(outlines), items=outlines)
|
||||
|
||||
|
||||
async def _continue_outline(
|
||||
request: OutlineGenerateRequest,
|
||||
project: Project,
|
||||
existing_outlines: List[Outline],
|
||||
db: AsyncSession
|
||||
) -> OutlineListResponse:
|
||||
"""续写大纲"""
|
||||
logger.info(f"续写大纲 - 项目: {project.id}, 已有: {len(existing_outlines)} 章")
|
||||
|
||||
# 分析已有大纲
|
||||
current_chapter_count = len(existing_outlines)
|
||||
last_chapter_number = existing_outlines[-1].order_index
|
||||
|
||||
# 获取最近2章的剧情
|
||||
recent_outlines = existing_outlines[-2:] if len(existing_outlines) >= 2 else existing_outlines
|
||||
recent_plot = "\n".join([
|
||||
f"第{o.order_index}章《{o.title}》: {o.content}"
|
||||
for o in recent_outlines
|
||||
])
|
||||
# logger.debug(f"最近三章内容:{recent_plot}")
|
||||
# 全部章节概览
|
||||
all_chapters_brief = "\n".join([
|
||||
f"第{o.order_index}章: {o.title}"
|
||||
for o in existing_outlines
|
||||
])
|
||||
|
||||
# 获取角色信息
|
||||
characters_result = await db.execute(
|
||||
select(Character).where(Character.project_id == project.id)
|
||||
)
|
||||
characters = characters_result.scalars().all()
|
||||
characters_info = "\n".join([
|
||||
f"- {char.name} ({'组织' if char.is_organization else '角色'}, {char.role_type}): "
|
||||
f"{char.personality[:100] if char.personality else '暂无描述'}"
|
||||
for char in characters
|
||||
])
|
||||
|
||||
# 情节阶段指导
|
||||
stage_instructions = {
|
||||
"development": "继续展开情节,深化角色关系,推进主线冲突",
|
||||
"climax": "进入故事高潮,矛盾激化,关键冲突爆发",
|
||||
"ending": "解决主要冲突,收束伏笔,给出结局"
|
||||
}
|
||||
stage_instruction = stage_instructions.get(request.plot_stage, "")
|
||||
|
||||
# 使用标准续写提示词模板
|
||||
prompt = prompt_service.get_outline_continue_prompt(
|
||||
title=project.title,
|
||||
theme=request.theme or project.theme or "未设定",
|
||||
genre=request.genre or project.genre or "通用",
|
||||
narrative_perspective=request.narrative_perspective,
|
||||
chapter_count=request.chapter_count,
|
||||
time_period=project.world_time_period or "未设定",
|
||||
location=project.world_location or "未设定",
|
||||
atmosphere=project.world_atmosphere or "未设定",
|
||||
rules=project.world_rules or "未设定",
|
||||
characters_info=characters_info or "暂无角色信息",
|
||||
current_chapter_count=current_chapter_count,
|
||||
all_chapters_brief=all_chapters_brief,
|
||||
recent_plot=recent_plot,
|
||||
plot_stage_instruction=stage_instruction,
|
||||
start_chapter=last_chapter_number + 1,
|
||||
story_direction=request.story_direction or "自然延续",
|
||||
requirements=request.requirements or ""
|
||||
)
|
||||
|
||||
# 调用AI
|
||||
ai_response = await ai_service.generate_text(
|
||||
prompt=prompt,
|
||||
provider=request.provider,
|
||||
model=request.model
|
||||
)
|
||||
|
||||
# 解析响应
|
||||
outline_data = _parse_ai_response(ai_response)
|
||||
|
||||
# 保存续写的大纲
|
||||
new_outlines = await _save_outlines(
|
||||
project.id, outline_data, db, start_index=last_chapter_number + 1
|
||||
)
|
||||
|
||||
# 记录历史
|
||||
history = GenerationHistory(
|
||||
project_id=project.id,
|
||||
prompt=prompt,
|
||||
generated_content=ai_response,
|
||||
model=request.model or "default"
|
||||
)
|
||||
db.add(history)
|
||||
|
||||
await db.commit()
|
||||
|
||||
for outline in new_outlines:
|
||||
await db.refresh(outline)
|
||||
|
||||
# 返回所有大纲(包括旧的和新的)
|
||||
all_result = await db.execute(
|
||||
select(Outline)
|
||||
.where(Outline.project_id == project.id)
|
||||
.order_by(Outline.order_index)
|
||||
)
|
||||
all_outlines = all_result.scalars().all()
|
||||
|
||||
logger.info(f"续写完成 - 新增 {len(new_outlines)} 章,总计 {len(all_outlines)} 章")
|
||||
return OutlineListResponse(total=len(all_outlines), items=all_outlines)
|
||||
|
||||
|
||||
def _parse_ai_response(ai_response: str) -> list:
|
||||
"""解析AI响应为章节数据列表"""
|
||||
try:
|
||||
# 清理响应文本
|
||||
cleaned_text = ai_response.strip()
|
||||
if cleaned_text.startswith('```json'):
|
||||
cleaned_text = cleaned_text[7:]
|
||||
if cleaned_text.startswith('```'):
|
||||
cleaned_text = cleaned_text[3:]
|
||||
if cleaned_text.endswith('```'):
|
||||
cleaned_text = cleaned_text[:-3]
|
||||
cleaned_text = cleaned_text.strip()
|
||||
|
||||
outline_data = json.loads(cleaned_text)
|
||||
|
||||
# 确保是列表格式
|
||||
if not isinstance(outline_data, list):
|
||||
# 如果是对象,尝试提取chapters字段
|
||||
if isinstance(outline_data, dict):
|
||||
outline_data = outline_data.get("chapters", [outline_data])
|
||||
else:
|
||||
outline_data = [outline_data]
|
||||
|
||||
return outline_data
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"AI响应解析失败: {e}")
|
||||
# 返回一个包含原始内容的章节
|
||||
return [{
|
||||
"title": "AI生成的大纲",
|
||||
"content": ai_response[:1000],
|
||||
"summary": ai_response[:1000]
|
||||
}]
|
||||
|
||||
|
||||
async def _save_outlines(
|
||||
project_id: str,
|
||||
outline_data: list,
|
||||
db: AsyncSession,
|
||||
start_index: int = 1
|
||||
) -> List[Outline]:
|
||||
"""保存大纲到数据库"""
|
||||
outlines = []
|
||||
|
||||
for idx, chapter_data in enumerate(outline_data):
|
||||
order_idx = chapter_data.get("chapter_number", start_index + idx)
|
||||
title = chapter_data.get("title", f"第{order_idx}章")
|
||||
|
||||
# 优先使用summary,其次content
|
||||
content = chapter_data.get("summary") or chapter_data.get("content", "")
|
||||
|
||||
# 如果有额外信息,添加到内容中
|
||||
if "key_events" in chapter_data:
|
||||
content += f"\n\n关键事件:" + "、".join(chapter_data["key_events"])
|
||||
if "characters_involved" in chapter_data:
|
||||
content += f"\n涉及角色:" + "、".join(chapter_data["characters_involved"])
|
||||
|
||||
# 创建大纲
|
||||
outline = Outline(
|
||||
project_id=project_id,
|
||||
title=title,
|
||||
content=content,
|
||||
structure=json.dumps(chapter_data, ensure_ascii=False),
|
||||
order_index=order_idx
|
||||
)
|
||||
db.add(outline)
|
||||
outlines.append(outline)
|
||||
|
||||
# 同步创建章节记录
|
||||
chapter = Chapter(
|
||||
project_id=project_id,
|
||||
chapter_number=order_idx,
|
||||
title=title,
|
||||
summary=content[:500] if len(content) > 500 else content,
|
||||
status="draft"
|
||||
)
|
||||
db.add(chapter)
|
||||
|
||||
return outlines
|
||||
@@ -0,0 +1,124 @@
|
||||
"""AI去味API - 核心特色功能"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.generation_history import GenerationHistory
|
||||
from app.schemas.polish import PolishRequest, PolishResponse
|
||||
from app.services.ai_service import ai_service
|
||||
from app.services.prompt_service import prompt_service
|
||||
from app.logger import get_logger
|
||||
|
||||
router = APIRouter(prefix="/polish", tags=["AI去味"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@router.post("", response_model=PolishResponse, summary="AI去味")
|
||||
async def polish_text(
|
||||
request: PolishRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
AI去味 - 将AI生成的文本改写得更像人类作家的手笔
|
||||
|
||||
核心功能:
|
||||
- 去除AI痕迹(工整排比、重复修辞、机械总结)
|
||||
- 增加人性化(口语化、不完美细节、真实情感)
|
||||
- 优化叙事(自然节奏、简单词汇、松弛感)
|
||||
- 让对话更生活化
|
||||
|
||||
这是本项目的核心特色功能!
|
||||
"""
|
||||
try:
|
||||
# 构建AI去味提示词
|
||||
prompt = prompt_service.get_denoising_prompt(
|
||||
original_text=request.original_text
|
||||
)
|
||||
|
||||
logger.info(f"开始AI去味处理,原文长度: {len(request.original_text)}")
|
||||
|
||||
# 调用AI进行去味处理
|
||||
polished_text = await ai_service.generate_text(
|
||||
prompt=prompt,
|
||||
provider=request.provider,
|
||||
model=request.model,
|
||||
temperature=request.temperature,
|
||||
max_tokens=len(request.original_text) * 2 # 预留足够token
|
||||
)
|
||||
|
||||
# 计算字数
|
||||
word_count_before = len(request.original_text)
|
||||
word_count_after = len(polished_text)
|
||||
|
||||
logger.info(f"AI去味完成,处理后长度: {word_count_after}")
|
||||
|
||||
# 如果提供了项目ID,记录到历史
|
||||
if request.project_id:
|
||||
history = GenerationHistory(
|
||||
project_id=request.project_id,
|
||||
generation_type="polish",
|
||||
prompt=f"原文: {request.original_text[:100]}...",
|
||||
result=polished_text,
|
||||
provider=request.provider or "default",
|
||||
model=request.model or "default"
|
||||
)
|
||||
db.add(history)
|
||||
await db.commit()
|
||||
|
||||
return PolishResponse(
|
||||
original_text=request.original_text,
|
||||
polished_text=polished_text,
|
||||
word_count_before=word_count_before,
|
||||
word_count_after=word_count_after
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"AI去味失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"AI去味失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/batch", summary="批量AI去味")
|
||||
async def polish_batch(
|
||||
texts: list[str],
|
||||
project_id: int = None,
|
||||
provider: str = None,
|
||||
model: str = None,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
批量处理多个文本的AI去味
|
||||
|
||||
适用于一次性处理多个章节或段落
|
||||
"""
|
||||
try:
|
||||
results = []
|
||||
|
||||
for idx, text in enumerate(texts):
|
||||
logger.info(f"处理第 {idx+1}/{len(texts)} 个文本")
|
||||
|
||||
prompt = prompt_service.get_denoising_prompt(original_text=text)
|
||||
|
||||
polished_text = await ai_service.generate_text(
|
||||
prompt=prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
)
|
||||
|
||||
results.append({
|
||||
"index": idx,
|
||||
"original": text,
|
||||
"polished": polished_text,
|
||||
"word_count_before": len(text),
|
||||
"word_count_after": len(polished_text)
|
||||
})
|
||||
|
||||
logger.info(f"批量AI去味完成,共处理 {len(results)} 个文本")
|
||||
|
||||
return {
|
||||
"total": len(results),
|
||||
"results": results
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量AI去味失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"批量AI去味失败: {str(e)}")
|
||||
@@ -0,0 +1,414 @@
|
||||
"""项目管理API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import Response
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, delete
|
||||
from typing import List
|
||||
from app.database import get_db
|
||||
from app.models.project import Project
|
||||
from app.models.character import Character
|
||||
from app.models.outline import Outline
|
||||
from app.models.chapter import Chapter
|
||||
from app.models.generation_history import GenerationHistory
|
||||
from app.models.relationship import CharacterRelationship, Organization, OrganizationMember
|
||||
from app.schemas.project import (
|
||||
ProjectCreate,
|
||||
ProjectUpdate,
|
||||
ProjectResponse,
|
||||
ProjectListResponse
|
||||
)
|
||||
from app.logger import get_logger
|
||||
from app.utils.data_consistency import (
|
||||
run_full_data_consistency_check,
|
||||
fix_missing_organization_records,
|
||||
fix_organization_member_counts
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
router = APIRouter(prefix="/projects", tags=["项目管理"])
|
||||
|
||||
|
||||
@router.post("", response_model=ProjectResponse, summary="创建项目")
|
||||
async def create_project(
|
||||
project: ProjectCreate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
try:
|
||||
logger.info(f"创建新项目: {project.title}")
|
||||
db_project = Project(**project.model_dump())
|
||||
db.add(db_project)
|
||||
await db.commit()
|
||||
await db.refresh(db_project)
|
||||
logger.info(f"项目创建成功: {db_project.id}")
|
||||
return db_project
|
||||
except Exception as e:
|
||||
logger.error(f"创建项目失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.get("", response_model=ProjectListResponse, summary="获取项目列表")
|
||||
async def get_projects(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取所有项目列表"""
|
||||
try:
|
||||
logger.debug(f"获取项目列表: skip={skip}, limit={limit}")
|
||||
count_result = await db.execute(select(func.count(Project.id)))
|
||||
total = count_result.scalar_one()
|
||||
|
||||
result = await db.execute(
|
||||
select(Project)
|
||||
.order_by(Project.updated_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
projects = result.scalars().all()
|
||||
logger.info(f"获取项目列表成功: 共{total}个项目")
|
||||
|
||||
return ProjectListResponse(total=total, items=projects)
|
||||
except Exception as e:
|
||||
logger.error(f"获取项目列表失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/{project_id}", response_model=ProjectResponse, summary="获取项目详情")
|
||||
async def get_project(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
try:
|
||||
logger.debug(f"获取项目详情: {project_id}")
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
logger.info(f"获取项目详情成功: {project.title}")
|
||||
return project
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取项目详情失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.put("/{project_id}", response_model=ProjectResponse, summary="更新项目")
|
||||
async def update_project(
|
||||
project_id: str,
|
||||
project_update: ProjectUpdate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
try:
|
||||
logger.info(f"更新项目: {project_id}")
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
update_data = project_update.model_dump(exclude_unset=True)
|
||||
logger.debug(f"更新字段: {list(update_data.keys())}")
|
||||
for field, value in update_data.items():
|
||||
setattr(project, field, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(project)
|
||||
logger.info(f"项目更新成功: {project.title}")
|
||||
return project
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新项目失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.delete("/{project_id}", summary="删除项目")
|
||||
async def delete_project(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
try:
|
||||
logger.info(f"删除项目: {project_id}")
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
project_title = project.title
|
||||
|
||||
relationships_result = await db.execute(
|
||||
delete(CharacterRelationship).where(CharacterRelationship.project_id == project_id)
|
||||
)
|
||||
logger.debug(f"删除角色关系数: {relationships_result.rowcount}")
|
||||
|
||||
orgs_result = await db.execute(
|
||||
select(Organization).where(Organization.project_id == project_id)
|
||||
)
|
||||
orgs = orgs_result.scalars().all()
|
||||
org_member_count = 0
|
||||
for org in orgs:
|
||||
members_result = await db.execute(
|
||||
delete(OrganizationMember).where(OrganizationMember.organization_id == org.id)
|
||||
)
|
||||
org_member_count += members_result.rowcount
|
||||
logger.debug(f"删除组织成员数: {org_member_count}")
|
||||
|
||||
organizations_result = await db.execute(
|
||||
delete(Organization).where(Organization.project_id == project_id)
|
||||
)
|
||||
logger.debug(f"删除组织数: {organizations_result.rowcount}")
|
||||
|
||||
history_result = await db.execute(
|
||||
delete(GenerationHistory).where(GenerationHistory.project_id == project_id)
|
||||
)
|
||||
logger.debug(f"删除生成历史数: {history_result.rowcount}")
|
||||
|
||||
chapters_result = await db.execute(
|
||||
delete(Chapter).where(Chapter.project_id == project_id)
|
||||
)
|
||||
logger.debug(f"删除章节数: {chapters_result.rowcount}")
|
||||
|
||||
outlines_result = await db.execute(
|
||||
delete(Outline).where(Outline.project_id == project_id)
|
||||
)
|
||||
logger.debug(f"删除大纲数: {outlines_result.rowcount}")
|
||||
|
||||
characters_result = await db.execute(
|
||||
delete(Character).where(Character.project_id == project_id)
|
||||
)
|
||||
logger.debug(f"删除角色数: {characters_result.rowcount}")
|
||||
|
||||
await db.delete(project)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"项目删除成功: {project_title}")
|
||||
return {"message": "项目及所有关联数据删除成功"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"删除项目失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/{project_id}/export", summary="导出项目章节为TXT")
|
||||
async def export_project_chapters(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
导出项目的所有章节内容为TXT文本文件
|
||||
按章节顺序组织,包含项目基本信息
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始导出项目: {project_id}")
|
||||
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
chapters_result = await db.execute(
|
||||
select(Chapter)
|
||||
.where(Chapter.project_id == project_id)
|
||||
.order_by(Chapter.chapter_number)
|
||||
)
|
||||
chapters = chapters_result.scalars().all()
|
||||
|
||||
if not chapters:
|
||||
logger.warning(f"项目没有章节: {project_id}")
|
||||
raise HTTPException(status_code=404, detail="项目没有任何章节")
|
||||
|
||||
txt_content = []
|
||||
|
||||
txt_content.append("=" * 80)
|
||||
txt_content.append(f"项目标题: {project.title}")
|
||||
txt_content.append("=" * 80)
|
||||
|
||||
if project.description:
|
||||
txt_content.append(f"\n简介: {project.description}\n")
|
||||
|
||||
if project.theme:
|
||||
txt_content.append(f"主题: {project.theme}")
|
||||
|
||||
if project.genre:
|
||||
txt_content.append(f"类型: {project.genre}")
|
||||
|
||||
txt_content.append(f"总章节数: {len(chapters)}")
|
||||
txt_content.append(f"总字数: {project.current_words}")
|
||||
txt_content.append("\n" + "=" * 80 + "\n\n")
|
||||
|
||||
for chapter in chapters:
|
||||
txt_content.append(f"第 {chapter.chapter_number} 章 {chapter.title}")
|
||||
txt_content.append("-" * 80)
|
||||
txt_content.append("") # 空行
|
||||
|
||||
if chapter.content:
|
||||
txt_content.append(chapter.content)
|
||||
else:
|
||||
txt_content.append("(本章暂无内容)")
|
||||
|
||||
txt_content.append("\n\n" + "=" * 80 + "\n\n")
|
||||
|
||||
txt_content.append(f"--- 全文完 ---")
|
||||
txt_content.append(f"\n导出时间: {func.now()}")
|
||||
|
||||
final_content = "\n".join(txt_content)
|
||||
|
||||
safe_title = "".join(c for c in project.title if c.isalnum() or c in (' ', '-', '_', ',', '。', '、'))
|
||||
filename = f"{safe_title}.txt"
|
||||
|
||||
from urllib.parse import quote
|
||||
encoded_filename = quote(filename)
|
||||
|
||||
logger.info(f"导出成功: {filename}, 共{len(chapters)}章, {len(final_content)}字符")
|
||||
|
||||
return Response(
|
||||
content=final_content.encode('utf-8'),
|
||||
media_type="text/plain; charset=utf-8",
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
|
||||
"Content-Type": "text/plain; charset=utf-8"
|
||||
}
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"导出项目失败: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"导出失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/{project_id}/check-consistency", summary="检查数据一致性")
|
||||
async def check_project_consistency(
|
||||
project_id: str,
|
||||
auto_fix: bool = True,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
检查并修复项目的数据一致性问题
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
auto_fix: 是否自动修复问题(默认True)
|
||||
|
||||
返回检查报告,包含:
|
||||
- organization_records: 检查并修复缺失的Organization记录
|
||||
- member_counts: 检查并修复组织成员计数
|
||||
- relationships: 验证关系数据完整性
|
||||
- organization_members: 验证组织成员数据完整性
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始数据一致性检查: {project_id}, auto_fix={auto_fix}")
|
||||
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
report = await run_full_data_consistency_check(project_id, db, auto_fix)
|
||||
|
||||
logger.info(f"数据一致性检查完成: {project_id}")
|
||||
return report
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"数据一致性检查失败: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"检查失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/{project_id}/fix-organizations", summary="修复组织记录")
|
||||
async def fix_project_organizations(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
修复项目中缺失的Organization记录
|
||||
|
||||
为所有is_organization=True但没有Organization记录的Character创建记录
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始修复组织记录: {project_id}")
|
||||
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
fixed_count, total_count = await fix_missing_organization_records(project_id, db)
|
||||
|
||||
logger.info(f"组织记录修复完成: {project_id}, 修复{fixed_count}/{total_count}")
|
||||
return {
|
||||
"message": "组织记录修复完成",
|
||||
"fixed": fixed_count,
|
||||
"total": total_count
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"修复组织记录失败: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"修复失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/{project_id}/fix-member-counts", summary="修复成员计数")
|
||||
async def fix_project_member_counts(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
修复项目中所有组织的成员计数
|
||||
|
||||
从实际成员记录重新计算每个组织的member_count
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始修复成员计数: {project_id}")
|
||||
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
fixed_count, total_count = await fix_organization_member_counts(project_id, db)
|
||||
|
||||
logger.info(f"成员计数修复完成: {project_id}, 修复{fixed_count}/{total_count}")
|
||||
return {
|
||||
"message": "成员计数修复完成",
|
||||
"fixed": fixed_count,
|
||||
"total": total_count
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"修复成员计数失败: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"修复失败: {str(e)}")
|
||||
@@ -0,0 +1,209 @@
|
||||
"""关系管理API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, or_, and_
|
||||
from typing import List, Optional
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.relationship import (
|
||||
RelationshipType,
|
||||
CharacterRelationship,
|
||||
Organization,
|
||||
OrganizationMember
|
||||
)
|
||||
from app.models.character import Character
|
||||
from app.schemas.relationship import (
|
||||
RelationshipTypeResponse,
|
||||
CharacterRelationshipCreate,
|
||||
CharacterRelationshipUpdate,
|
||||
CharacterRelationshipResponse,
|
||||
RelationshipGraphData,
|
||||
RelationshipGraphNode,
|
||||
RelationshipGraphLink
|
||||
)
|
||||
from app.logger import get_logger
|
||||
|
||||
router = APIRouter(prefix="/relationships", tags=["关系管理"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@router.get("/types", response_model=List[RelationshipTypeResponse], summary="获取关系类型列表")
|
||||
async def get_relationship_types(db: AsyncSession = Depends(get_db)):
|
||||
"""获取所有预定义的关系类型"""
|
||||
result = await db.execute(select(RelationshipType).order_by(RelationshipType.category, RelationshipType.id))
|
||||
types = result.scalars().all()
|
||||
return types
|
||||
|
||||
|
||||
@router.get("/project/{project_id}", response_model=List[CharacterRelationshipResponse], summary="获取项目的所有关系")
|
||||
async def get_project_relationships(
|
||||
project_id: str,
|
||||
character_id: Optional[str] = Query(None, description="筛选特定角色的关系"),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取项目中的所有角色关系
|
||||
|
||||
- 如果提供character_id,则只返回与该角色相关的关系(作为发起方或接收方)
|
||||
- 否则返回项目中的所有关系
|
||||
"""
|
||||
query = select(CharacterRelationship).where(
|
||||
CharacterRelationship.project_id == project_id
|
||||
)
|
||||
|
||||
if character_id:
|
||||
query = query.where(
|
||||
or_(
|
||||
CharacterRelationship.character_from_id == character_id,
|
||||
CharacterRelationship.character_to_id == character_id
|
||||
)
|
||||
)
|
||||
|
||||
query = query.order_by(CharacterRelationship.created_at.desc())
|
||||
result = await db.execute(query)
|
||||
relationships = result.scalars().all()
|
||||
|
||||
logger.info(f"获取项目 {project_id} 的关系列表,共 {len(relationships)} 条")
|
||||
return relationships
|
||||
|
||||
|
||||
@router.get("/graph/{project_id}", response_model=RelationshipGraphData, summary="获取关系图谱数据")
|
||||
async def get_relationship_graph(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取用于可视化的关系图谱数据
|
||||
|
||||
返回格式:
|
||||
- nodes: 角色节点列表
|
||||
- links: 关系连线列表
|
||||
"""
|
||||
# 获取所有角色(节点)
|
||||
chars_result = await db.execute(
|
||||
select(Character).where(Character.project_id == project_id)
|
||||
)
|
||||
characters = chars_result.scalars().all()
|
||||
|
||||
nodes = [
|
||||
RelationshipGraphNode(
|
||||
id=c.id,
|
||||
name=c.name,
|
||||
type="organization" if c.is_organization else "character",
|
||||
role_type=c.role_type,
|
||||
avatar=c.avatar_url
|
||||
)
|
||||
for c in characters
|
||||
]
|
||||
|
||||
# 获取所有关系(边)
|
||||
rels_result = await db.execute(
|
||||
select(CharacterRelationship).where(
|
||||
CharacterRelationship.project_id == project_id
|
||||
)
|
||||
)
|
||||
relationships = rels_result.scalars().all()
|
||||
|
||||
links = [
|
||||
RelationshipGraphLink(
|
||||
source=r.character_from_id,
|
||||
target=r.character_to_id,
|
||||
relationship=r.relationship_name or "未知关系",
|
||||
intimacy=r.intimacy_level,
|
||||
status=r.status
|
||||
)
|
||||
for r in relationships
|
||||
]
|
||||
|
||||
logger.info(f"获取项目 {project_id} 的关系图谱:{len(nodes)} 个节点,{len(links)} 条关系")
|
||||
return RelationshipGraphData(nodes=nodes, links=links)
|
||||
|
||||
|
||||
@router.post("/", response_model=CharacterRelationshipResponse, summary="创建角色关系")
|
||||
async def create_relationship(
|
||||
relationship: CharacterRelationshipCreate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
手动创建角色关系
|
||||
|
||||
- 需要提供角色A和角色B的ID
|
||||
- 可以指定预定义的关系类型或自定义关系名称
|
||||
- 可以设置亲密度、状态等属性
|
||||
"""
|
||||
# 验证角色是否存在
|
||||
char_from = await db.execute(
|
||||
select(Character).where(Character.id == relationship.character_from_id)
|
||||
)
|
||||
char_to = await db.execute(
|
||||
select(Character).where(Character.id == relationship.character_to_id)
|
||||
)
|
||||
|
||||
if not char_from.scalar_one_or_none():
|
||||
raise HTTPException(status_code=404, detail=f"角色A(ID: {relationship.character_from_id})不存在")
|
||||
if not char_to.scalar_one_or_none():
|
||||
raise HTTPException(status_code=404, detail=f"角色B(ID: {relationship.character_to_id})不存在")
|
||||
|
||||
# 创建关系
|
||||
db_relationship = CharacterRelationship(
|
||||
**relationship.model_dump(),
|
||||
source="manual"
|
||||
)
|
||||
db.add(db_relationship)
|
||||
await db.commit()
|
||||
await db.refresh(db_relationship)
|
||||
|
||||
logger.info(f"创建关系成功:{relationship.character_from_id} -> {relationship.character_to_id}")
|
||||
return db_relationship
|
||||
|
||||
|
||||
@router.put("/{relationship_id}", response_model=CharacterRelationshipResponse, summary="更新关系")
|
||||
async def update_relationship(
|
||||
relationship_id: str,
|
||||
relationship: CharacterRelationshipUpdate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""更新角色关系的属性(亲密度、状态等)"""
|
||||
result = await db.execute(
|
||||
select(CharacterRelationship).where(
|
||||
CharacterRelationship.id == relationship_id
|
||||
)
|
||||
)
|
||||
db_rel = result.scalar_one_or_none()
|
||||
|
||||
if not db_rel:
|
||||
raise HTTPException(status_code=404, detail="关系不存在")
|
||||
|
||||
# 更新字段
|
||||
update_data = relationship.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(db_rel, field, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(db_rel)
|
||||
|
||||
logger.info(f"更新关系成功:{relationship_id}")
|
||||
return db_rel
|
||||
|
||||
|
||||
@router.delete("/{relationship_id}", summary="删除关系")
|
||||
async def delete_relationship(
|
||||
relationship_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""删除角色关系"""
|
||||
result = await db.execute(
|
||||
select(CharacterRelationship).where(
|
||||
CharacterRelationship.id == relationship_id
|
||||
)
|
||||
)
|
||||
db_rel = result.scalar_one_or_none()
|
||||
|
||||
if not db_rel:
|
||||
raise HTTPException(status_code=404, detail="关系不存在")
|
||||
|
||||
await db.delete(db_rel)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"删除关系成功:{relationship_id}")
|
||||
return {"message": "关系删除成功", "id": relationship_id}
|
||||
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
用户管理 API
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, Request, Depends
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
from app.user_manager import user_manager, User
|
||||
|
||||
router = APIRouter(prefix="/users", tags=["用户管理"])
|
||||
|
||||
|
||||
def require_login(request: Request):
|
||||
"""依赖:要求用户已登录"""
|
||||
if not hasattr(request.state, "user") or not request.state.user:
|
||||
raise HTTPException(status_code=401, detail="需要登录")
|
||||
return request.state.user
|
||||
|
||||
|
||||
def require_admin(request: Request):
|
||||
"""依赖:要求用户为管理员"""
|
||||
user = require_login(request)
|
||||
if not request.state.is_admin:
|
||||
raise HTTPException(status_code=403, detail="需要管理员权限")
|
||||
return user
|
||||
|
||||
|
||||
class SetAdminRequest(BaseModel):
|
||||
user_id: str
|
||||
is_admin: bool
|
||||
|
||||
|
||||
@router.get("/current")
|
||||
async def get_current_user(user: User = Depends(require_login)):
|
||||
"""获取当前登录用户信息"""
|
||||
return user.dict()
|
||||
|
||||
|
||||
@router.get("", response_model=List[dict])
|
||||
async def list_users(admin_user: User = Depends(require_admin)):
|
||||
"""
|
||||
获取所有用户列表(仅管理员)
|
||||
"""
|
||||
users = await user_manager.get_all_users()
|
||||
return [user.dict() for user in users]
|
||||
|
||||
|
||||
@router.post("/set-admin")
|
||||
async def set_admin(
|
||||
data: SetAdminRequest,
|
||||
request: Request,
|
||||
admin_user: User = Depends(require_admin)
|
||||
):
|
||||
"""
|
||||
设置用户的管理员权限(仅管理员)
|
||||
|
||||
限制:
|
||||
- 不能撤销自己的管理员权限
|
||||
- 至少保留一个管理员
|
||||
"""
|
||||
# 检查是否尝试撤销自己的权限
|
||||
if data.user_id == admin_user.user_id and not data.is_admin:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="不能撤销自己的管理员权限"
|
||||
)
|
||||
|
||||
# 尝试设置管理员权限
|
||||
success = await user_manager.set_admin(data.user_id, data.is_admin)
|
||||
|
||||
if not success:
|
||||
if not data.is_admin:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="无法撤销管理员权限,至少需要保留一个管理员"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="用户不存在"
|
||||
)
|
||||
|
||||
return {
|
||||
"message": f"已{'授予' if data.is_admin else '撤销'}管理员权限",
|
||||
"user_id": data.user_id,
|
||||
"is_admin": data.is_admin
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/{user_id}")
|
||||
async def delete_user(
|
||||
user_id: str,
|
||||
admin_user: User = Depends(require_admin)
|
||||
):
|
||||
"""
|
||||
删除用户(仅管理员)
|
||||
|
||||
限制:
|
||||
- 不能删除管理员用户
|
||||
"""
|
||||
success = await user_manager.delete_user(user_id)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="无法删除该用户(用户不存在或为管理员)"
|
||||
)
|
||||
|
||||
return {
|
||||
"message": "用户已删除",
|
||||
"user_id": user_id
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{user_id}")
|
||||
async def get_user(
|
||||
user_id: str,
|
||||
admin_user: User = Depends(require_admin)
|
||||
):
|
||||
"""获取指定用户信息(仅管理员)"""
|
||||
user = await user_manager.get_user(user_id)
|
||||
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
return user.dict()
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user