This commit is contained in:
xiamuceer
2025-10-30 11:14:43 +08:00
parent b97410d973
commit 0f6c2d344a
91 changed files with 22309 additions and 0 deletions
+1
View File
@@ -0,0 +1 @@
"""API路由模块"""
+230
View File
@@ -0,0 +1,230 @@
"""
认证 API - LinuxDO OAuth2 登录 + 本地账户登录
"""
from fastapi import APIRouter, HTTPException, Response, Request
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
from typing import Optional
import hashlib
from app.services.oauth_service import LinuxDOOAuthService
from app.user_manager import user_manager
from app.database import init_db
from app.logger import get_logger
from app.config import settings
logger = get_logger(__name__)
router = APIRouter(prefix="/auth", tags=["认证"])
# OAuth2 服务实例
oauth_service = LinuxDOOAuthService()
# State 临时存储(生产环境应使用 Redis)
_state_storage = {}
class AuthUrlResponse(BaseModel):
auth_url: str
state: str
class LocalLoginRequest(BaseModel):
"""本地登录请求"""
username: str
password: str
class LocalLoginResponse(BaseModel):
"""本地登录响应"""
success: bool
message: str
user: Optional[dict] = None
@router.get("/config")
async def get_auth_config():
"""获取认证配置信息"""
return {
"local_auth_enabled": settings.LOCAL_AUTH_ENABLED,
"linuxdo_enabled": bool(settings.LINUXDO_CLIENT_ID and settings.LINUXDO_CLIENT_SECRET)
}
@router.post("/local/login", response_model=LocalLoginResponse)
async def local_login(request: LocalLoginRequest, response: Response):
"""本地账户登录"""
# 检查是否启用本地登录
if not settings.LOCAL_AUTH_ENABLED:
raise HTTPException(status_code=403, detail="本地账户登录未启用")
# 检查是否配置了本地账户
if not settings.LOCAL_AUTH_USERNAME or not settings.LOCAL_AUTH_PASSWORD:
raise HTTPException(status_code=500, detail="本地账户未配置")
# 验证用户名和密码
if request.username != settings.LOCAL_AUTH_USERNAME or request.password != settings.LOCAL_AUTH_PASSWORD:
raise HTTPException(status_code=401, detail="用户名或密码错误")
# 生成本地用户ID(使用用户名的hash)
user_id = f"local_{hashlib.md5(request.username.encode()).hexdigest()[:16]}"
# 创建或更新本地用户
user = await user_manager.create_or_update_from_linuxdo(
linuxdo_id=user_id,
username=request.username,
display_name=settings.LOCAL_AUTH_DISPLAY_NAME,
avatar_url=None,
trust_level=9 # 本地用户给予高信任级别
)
# 初始化用户数据库
try:
await init_db(user.user_id)
logger.info(f"本地用户 {user.user_id} 数据库初始化成功")
except Exception as e:
logger.error(f"本地用户 {user.user_id} 数据库初始化失败: {e}")
# 设置 Cookie7天有效)
response.set_cookie(
key="user_id",
value=user.user_id,
max_age=7 * 24 * 60 * 60, # 7天
httponly=True,
samesite="lax"
)
return LocalLoginResponse(
success=True,
message="登录成功",
user=user.dict()
)
@router.get("/linuxdo/url", response_model=AuthUrlResponse)
async def get_linuxdo_auth_url():
"""获取 LinuxDO 授权 URL"""
state = oauth_service.generate_state()
auth_url = oauth_service.get_authorization_url(state)
# 临时存储 state5分钟有效)
_state_storage[state] = True
return AuthUrlResponse(auth_url=auth_url, state=state)
async def _handle_callback(
code: Optional[str] = None,
state: Optional[str] = None,
error: Optional[str] = None,
response: Response = None
):
"""
LinuxDO OAuth2 回调处理
成功后重定向到前端首页,并设置 user_id Cookie
"""
# 检查是否有错误
if error:
raise HTTPException(status_code=400, detail=f"授权失败: {error}")
# 检查必需参数
if not code or not state:
raise HTTPException(status_code=400, detail="缺少 code 或 state 参数")
# 验证 state(防止 CSRF
if state not in _state_storage:
raise HTTPException(status_code=400, detail="无效的 state 参数")
# 删除已使用的 state
del _state_storage[state]
# 1. 使用 code 获取 access_token
token_data = await oauth_service.get_access_token(code)
if not token_data or "access_token" not in token_data:
raise HTTPException(status_code=400, detail="获取访问令牌失败")
access_token = token_data["access_token"]
# 2. 使用 access_token 获取用户信息
user_info = await oauth_service.get_user_info(access_token)
if not user_info:
raise HTTPException(status_code=400, detail="获取用户信息失败")
# 3. 创建或更新用户
linuxdo_id = str(user_info.get("id"))
username = user_info.get("username", "")
display_name = user_info.get("name", username)
avatar_url = user_info.get("avatar_url")
trust_level = user_info.get("trust_level", 0)
user = await user_manager.create_or_update_from_linuxdo(
linuxdo_id=linuxdo_id,
username=username,
display_name=display_name,
avatar_url=avatar_url,
trust_level=trust_level
)
# 3.5. 初始化用户数据库(如果是新用户)
try:
await init_db(user.user_id)
logger.info(f"用户 {user.user_id} 数据库初始化成功")
except Exception as e:
logger.error(f"用户 {user.user_id} 数据库初始化失败: {e}")
# 继续执行,不影响登录流程(可能是已存在的用户)
# 4. 设置 Cookie 并重定向到前端回调页面
# 使用配置的前端URL,支持不同的部署环境
frontend_url = settings.FRONTEND_URL.rstrip('/')
redirect_url = f"{frontend_url}/auth/callback"
logger.info(f"OAuth回调成功,重定向到前端: {redirect_url}")
redirect_response = RedirectResponse(url=redirect_url)
# 设置 httponly Cookie7天有效)
redirect_response.set_cookie(
key="user_id",
value=user.user_id,
max_age=7 * 24 * 60 * 60, # 7天
httponly=True,
samesite="lax"
)
return redirect_response
@router.get("/linuxdo/callback")
async def linuxdo_callback(
code: Optional[str] = None,
state: Optional[str] = None,
error: Optional[str] = None,
response: Response = None
):
"""LinuxDO OAuth2 回调处理(标准路径)"""
return await _handle_callback(code, state, error, response)
@router.get("/callback")
async def callback_alias(
code: Optional[str] = None,
state: Optional[str] = None,
error: Optional[str] = None,
response: Response = None
):
"""LinuxDO OAuth2 回调处理(兼容路径)"""
return await _handle_callback(code, state, error, response)
@router.post("/logout")
async def logout(response: Response):
"""退出登录"""
response.delete_cookie("user_id")
return {"message": "退出登录成功"}
@router.get("/user")
async def get_current_user(request: Request):
"""获取当前登录用户信息"""
if not hasattr(request.state, "user") or not request.state.user:
raise HTTPException(status_code=401, detail="未登录")
return request.state.user.dict()
+655
View File
@@ -0,0 +1,655 @@
"""章节管理API"""
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
import json
import asyncio
from app.database import get_db
from app.models.chapter import Chapter
from app.models.project import Project
from app.models.outline import Outline
from app.models.character import Character
from app.models.generation_history import GenerationHistory
from app.schemas.chapter import (
ChapterCreate,
ChapterUpdate,
ChapterResponse,
ChapterListResponse
)
from app.services.ai_service import ai_service
from app.services.prompt_service import prompt_service
from app.logger import get_logger
router = APIRouter(prefix="/chapters", tags=["章节管理"])
logger = get_logger(__name__)
@router.post("", response_model=ChapterResponse, summary="创建章节")
async def create_chapter(
chapter: ChapterCreate,
db: AsyncSession = Depends(get_db)
):
"""创建新的章节"""
# 验证项目是否存在
result = await db.execute(
select(Project).where(Project.id == chapter.project_id)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 计算字数
word_count = len(chapter.content)
db_chapter = Chapter(
**chapter.model_dump(),
word_count=word_count
)
db.add(db_chapter)
# 更新项目的当前字数
project.current_words = project.current_words + word_count
await db.commit()
await db.refresh(db_chapter)
return db_chapter
@router.get("/project/{project_id}", response_model=ChapterListResponse, summary="获取项目的所有章节")
async def get_project_chapters(
project_id: str,
db: AsyncSession = Depends(get_db)
):
"""获取指定项目的所有章节(路径参数版本)"""
# 获取总数
count_result = await db.execute(
select(func.count(Chapter.id)).where(Chapter.project_id == project_id)
)
total = count_result.scalar_one()
# 获取章节列表
result = await db.execute(
select(Chapter)
.where(Chapter.project_id == project_id)
.order_by(Chapter.chapter_number)
)
chapters = result.scalars().all()
return ChapterListResponse(total=total, items=chapters)
@router.get("/{chapter_id}", response_model=ChapterResponse, summary="获取章节详情")
async def get_chapter(
chapter_id: str,
db: AsyncSession = Depends(get_db)
):
"""根据ID获取章节详情"""
result = await db.execute(
select(Chapter).where(Chapter.id == chapter_id)
)
chapter = result.scalar_one_or_none()
if not chapter:
raise HTTPException(status_code=404, detail="章节不存在")
return chapter
@router.put("/{chapter_id}", response_model=ChapterResponse, summary="更新章节")
async def update_chapter(
chapter_id: str,
chapter_update: ChapterUpdate,
db: AsyncSession = Depends(get_db)
):
"""更新章节信息"""
result = await db.execute(
select(Chapter).where(Chapter.id == chapter_id)
)
chapter = result.scalar_one_or_none()
if not chapter:
raise HTTPException(status_code=404, detail="章节不存在")
# 记录旧字数
old_word_count = chapter.word_count or 0
# 更新字段
update_data = chapter_update.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(chapter, field, value)
# 如果内容更新了,重新计算字数
if "content" in update_data and chapter.content:
new_word_count = len(chapter.content)
chapter.word_count = new_word_count
# 更新项目字数
result = await db.execute(
select(Project).where(Project.id == chapter.project_id)
)
project = result.scalar_one_or_none()
if project:
project.current_words = project.current_words - old_word_count + new_word_count
await db.commit()
await db.refresh(chapter)
return chapter
@router.delete("/{chapter_id}", summary="删除章节")
async def delete_chapter(
chapter_id: str,
db: AsyncSession = Depends(get_db)
):
"""删除章节"""
result = await db.execute(
select(Chapter).where(Chapter.id == chapter_id)
)
chapter = result.scalar_one_or_none()
if not chapter:
raise HTTPException(status_code=404, detail="章节不存在")
# 更新项目字数
result = await db.execute(
select(Project).where(Project.id == chapter.project_id)
)
project = result.scalar_one_or_none()
if project:
project.current_words = max(0, project.current_words - chapter.word_count)
await db.delete(chapter)
await db.commit()
return {"message": "章节删除成功"}
async def check_prerequisites(db: AsyncSession, chapter: Chapter) -> tuple[bool, str, list[Chapter]]:
"""
检查章节前置条件
Args:
db: 数据库会话
chapter: 当前章节
Returns:
(可否生成, 错误信息, 前置章节列表)
"""
# 如果是第一章,无需检查前置
if chapter.chapter_number == 1:
return True, "", []
# 查询所有前置章节(序号小于当前章节的)
result = await db.execute(
select(Chapter)
.where(Chapter.project_id == chapter.project_id)
.where(Chapter.chapter_number < chapter.chapter_number)
.order_by(Chapter.chapter_number)
)
previous_chapters = result.scalars().all()
# 检查是否所有前置章节都有内容
incomplete_chapters = [
ch for ch in previous_chapters
if not ch.content or ch.content.strip() == ""
]
if incomplete_chapters:
missing_numbers = [str(ch.chapter_number) for ch in incomplete_chapters]
error_msg = f"需要先完成前置章节:第 {', '.join(missing_numbers)}"
return False, error_msg, previous_chapters
return True, "", previous_chapters
@router.get("/{chapter_id}/can-generate", summary="检查章节是否可以生成")
async def check_can_generate(
chapter_id: str,
db: AsyncSession = Depends(get_db)
):
"""
检查章节是否满足生成条件
返回可生成状态和前置章节信息
"""
# 获取章节
result = await db.execute(
select(Chapter).where(Chapter.id == chapter_id)
)
chapter = result.scalar_one_or_none()
if not chapter:
raise HTTPException(status_code=404, detail="章节不存在")
# 检查前置条件
can_generate, error_msg, previous_chapters = await check_prerequisites(db, chapter)
# 构建前置章节信息
previous_info = [
{
"id": ch.id,
"chapter_number": ch.chapter_number,
"title": ch.title,
"has_content": bool(ch.content and ch.content.strip()),
"word_count": ch.word_count or 0
}
for ch in previous_chapters
]
return {
"can_generate": can_generate,
"reason": error_msg if not can_generate else "",
"previous_chapters": previous_info,
"chapter_number": chapter.chapter_number
}
@router.post("/{chapter_id}/generate", summary="AI创作章节内容")
async def generate_chapter_content(
chapter_id: str,
db: AsyncSession = Depends(get_db)
):
"""
根据大纲、前置章节内容和项目信息AI创作章节完整内容
要求:必须按顺序生成,确保前置章节都已完成
"""
# 获取章节
result = await db.execute(
select(Chapter).where(Chapter.id == chapter_id)
)
chapter = result.scalar_one_or_none()
if not chapter:
raise HTTPException(status_code=404, detail="章节不存在")
# 检查前置条件
can_generate, error_msg, previous_chapters = await check_prerequisites(db, chapter)
if not can_generate:
raise HTTPException(status_code=400, detail=error_msg)
try:
# 获取项目信息
project_result = await db.execute(
select(Project).where(Project.id == chapter.project_id)
)
project = project_result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 获取对应的大纲(使用新的查询确保获取最新数据)
outline_result = await db.execute(
select(Outline)
.where(Outline.project_id == chapter.project_id)
.where(Outline.order_index == chapter.chapter_number)
.execution_options(populate_existing=True)
)
outline = outline_result.scalar_one_or_none()
# 获取所有大纲用于上下文(使用新的查询确保获取最新数据)
all_outlines_result = await db.execute(
select(Outline)
.where(Outline.project_id == chapter.project_id)
.order_by(Outline.order_index)
.execution_options(populate_existing=True)
)
all_outlines = all_outlines_result.scalars().all()
outlines_context = "\n".join([
f"{o.order_index}{o.title}: {o.content[:100]}..."
for o in all_outlines
])
# 获取角色信息
characters_result = await db.execute(
select(Character).where(Character.project_id == chapter.project_id)
)
characters = characters_result.scalars().all()
characters_info = "\n".join([
f"- {c.name}({'组织' if c.is_organization else '角色'}, {c.role_type}): {c.personality[:100] if c.personality else ''}"
for c in characters
])
# 构建前置章节内容上下文(如果有前置章节)
previous_content = ""
if previous_chapters:
# Token控制:保留最近3章的完整内容,早期章节使用摘要
recent_chapters = previous_chapters[-3:] if len(previous_chapters) > 3 else previous_chapters
early_chapters = previous_chapters[:-3] if len(previous_chapters) > 3 else []
# 早期章节摘要
if early_chapters:
early_summary = "【前期剧情概要】\n" + "\n".join([
f"{ch.chapter_number}章《{ch.title}》:{ch.content[:200] if ch.content else ''}..."
for ch in early_chapters
])
previous_content += early_summary + "\n\n"
# 最近章节完整内容
if recent_chapters:
recent_content = "【最近章节完整内容】\n" + "\n\n".join([
f"=== 第{ch.chapter_number}章:{ch.title} ===\n{ch.content}"
for ch in recent_chapters
])
previous_content += recent_content
logger.info(f"构建前置上下文:{len(early_chapters)}章摘要 + {len(recent_chapters)}章完整内容")
# 根据是否有前置内容选择不同的提示词
if previous_content:
# 使用带上下文的提示词
prompt = prompt_service.get_chapter_generation_with_context_prompt(
title=project.title,
theme=project.theme or '',
genre=project.genre or '',
narrative_perspective=project.narrative_perspective or '第三人称',
time_period=project.world_time_period or '未设定',
location=project.world_location or '未设定',
atmosphere=project.world_atmosphere or '未设定',
rules=project.world_rules or '未设定',
characters_info=characters_info or '暂无角色信息',
outlines_context=outlines_context,
previous_content=previous_content,
chapter_number=chapter.chapter_number,
chapter_title=chapter.title,
chapter_outline=outline.content if outline else chapter.summary or '暂无大纲'
)
else:
# 第一章,使用原有提示词
prompt = prompt_service.get_chapter_generation_prompt(
title=project.title,
theme=project.theme or '',
genre=project.genre or '',
narrative_perspective=project.narrative_perspective or '第三人称',
time_period=project.world_time_period or '未设定',
location=project.world_location or '未设定',
atmosphere=project.world_atmosphere or '未设定',
rules=project.world_rules or '未设定',
characters_info=characters_info or '暂无角色信息',
outlines_context=outlines_context,
chapter_number=chapter.chapter_number,
chapter_title=chapter.title,
chapter_outline=outline.content if outline else chapter.summary or '暂无大纲'
)
logger.info(f"开始AI创作章节 {chapter_id}")
# 调用AI生成
ai_content = await ai_service.generate_text(
prompt=prompt
)
# 更新章节内容
old_word_count = chapter.word_count or 0
chapter.content = ai_content
new_word_count = len(ai_content)
chapter.word_count = new_word_count
chapter.status = "completed"
# 更新项目字数
project.current_words = project.current_words - old_word_count + new_word_count
# 记录生成历史
history = GenerationHistory(
project_id=chapter.project_id,
chapter_id=chapter.id,
prompt=f"创作章节: 第{chapter.chapter_number}{chapter.title}",
generated_content=ai_content[:500] if len(ai_content) > 500 else ai_content,
model="default"
)
db.add(history)
await db.commit()
await db.refresh(chapter)
logger.info(f"成功创作章节 {chapter_id},共 {new_word_count}")
return {"content": ai_content}
except Exception as e:
logger.error(f"创作章节失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"创作章节失败: {str(e)}")
@router.post("/{chapter_id}/generate-stream", summary="AI创作章节内容(流式)")
async def generate_chapter_content_stream(
chapter_id: str,
request: Request
):
"""
根据大纲、前置章节内容和项目信息AI创作章节完整内容(流式返回)
要求:必须按顺序生成,确保前置章节都已完成
注意:此函数不使用依赖注入的db,而是在生成器内部创建独立的数据库会话
以避免流式响应期间的连接泄漏问题
"""
# 预先验证章节存在性(使用临时会话)
async for temp_db in get_db(request):
try:
result = await temp_db.execute(
select(Chapter).where(Chapter.id == chapter_id)
)
chapter = result.scalar_one_or_none()
if not chapter:
raise HTTPException(status_code=404, detail="章节不存在")
# 检查前置条件
can_generate, error_msg, previous_chapters = await check_prerequisites(temp_db, chapter)
if not can_generate:
raise HTTPException(status_code=400, detail=error_msg)
# 保存前置章节数据供生成器使用
previous_chapters_data = [
{
'id': ch.id,
'chapter_number': ch.chapter_number,
'title': ch.title,
'content': ch.content
}
for ch in previous_chapters
]
finally:
await temp_db.close()
break
async def event_generator():
# 在生成器内部创建独立的数据库会话
db_session = None
db_committed = False
try:
# 创建新的数据库会话
async for db_session in get_db(request):
# 重新获取章节信息
chapter_result = await db_session.execute(
select(Chapter).where(Chapter.id == chapter_id)
)
current_chapter = chapter_result.scalar_one_or_none()
if not current_chapter:
yield f"data: {json.dumps({'type': 'error', 'error': '章节不存在'}, ensure_ascii=False)}\n\n"
return
# 获取项目信息
project_result = await db_session.execute(
select(Project).where(Project.id == current_chapter.project_id)
)
project = project_result.scalar_one_or_none()
if not project:
yield f"data: {json.dumps({'type': 'error', 'error': '项目不存在'}, ensure_ascii=False)}\n\n"
return
# 获取对应的大纲
outline_result = await db_session.execute(
select(Outline)
.where(Outline.project_id == current_chapter.project_id)
.where(Outline.order_index == current_chapter.chapter_number)
.execution_options(populate_existing=True)
)
outline = outline_result.scalar_one_or_none()
# 获取所有大纲用于上下文
all_outlines_result = await db_session.execute(
select(Outline)
.where(Outline.project_id == current_chapter.project_id)
.order_by(Outline.order_index)
.execution_options(populate_existing=True)
)
all_outlines = all_outlines_result.scalars().all()
outlines_context = "\n".join([
f"{o.order_index}{o.title}: {o.content[:100]}..."
for o in all_outlines
])
# 获取角色信息
characters_result = await db_session.execute(
select(Character).where(Character.project_id == current_chapter.project_id)
)
characters = characters_result.scalars().all()
characters_info = "\n".join([
f"- {c.name}({'组织' if c.is_organization else '角色'}, {c.role_type}): {c.personality[:100] if c.personality else ''}"
for c in characters
])
# 构建前置章节内容上下文(使用之前保存的数据)
previous_content = ""
if previous_chapters_data:
recent_chapters = previous_chapters_data[-3:] if len(previous_chapters_data) > 3 else previous_chapters_data
early_chapters = previous_chapters_data[:-3] if len(previous_chapters_data) > 3 else []
if early_chapters:
early_summary = "【前期剧情概要】\n" + "\n".join([
f"{ch['chapter_number']}章《{ch['title']}》:{ch['content'][:200] if ch['content'] else ''}..."
for ch in early_chapters
])
previous_content += early_summary + "\n\n"
if recent_chapters:
recent_content = "【最近章节完整内容】\n" + "\n\n".join([
f"=== 第{ch['chapter_number']}章:{ch['title']} ===\n{ch['content']}"
for ch in recent_chapters
])
previous_content += recent_content
logger.info(f"构建前置上下文:{len(early_chapters)}章摘要 + {len(recent_chapters)}章完整内容")
# 发送开始事件
yield f"data: {json.dumps({'type': 'start', 'message': '开始AI创作...'}, ensure_ascii=False)}\n\n"
# 根据是否有前置内容选择不同的提示词
if previous_content:
prompt = prompt_service.get_chapter_generation_with_context_prompt(
title=project.title,
theme=project.theme or '',
genre=project.genre or '',
narrative_perspective=project.narrative_perspective or '第三人称',
time_period=project.world_time_period or '未设定',
location=project.world_location or '未设定',
atmosphere=project.world_atmosphere or '未设定',
rules=project.world_rules or '未设定',
characters_info=characters_info or '暂无角色信息',
outlines_context=outlines_context,
previous_content=previous_content,
chapter_number=current_chapter.chapter_number,
chapter_title=current_chapter.title,
chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲'
)
else:
prompt = prompt_service.get_chapter_generation_prompt(
title=project.title,
theme=project.theme or '',
genre=project.genre or '',
narrative_perspective=project.narrative_perspective or '第三人称',
time_period=project.world_time_period or '未设定',
location=project.world_location or '未设定',
atmosphere=project.world_atmosphere or '未设定',
rules=project.world_rules or '未设定',
characters_info=characters_info or '暂无角色信息',
outlines_context=outlines_context,
chapter_number=current_chapter.chapter_number,
chapter_title=current_chapter.title,
chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲'
)
logger.info(f"开始AI流式创作章节 {chapter_id}")
# 流式生成内容
full_content = ""
async for chunk in ai_service.generate_text_stream(prompt=prompt):
full_content += chunk
yield f"data: {json.dumps({'type': 'content', 'content': chunk}, ensure_ascii=False)}\n\n"
await asyncio.sleep(0) # 让出控制权
# 更新章节内容到数据库
old_word_count = current_chapter.word_count or 0
current_chapter.content = full_content
new_word_count = len(full_content)
current_chapter.word_count = new_word_count
current_chapter.status = "completed"
# 更新项目字数
project.current_words = project.current_words - old_word_count + new_word_count
# 记录生成历史
history = GenerationHistory(
project_id=current_chapter.project_id,
chapter_id=current_chapter.id,
prompt=f"创作章节: 第{current_chapter.chapter_number}{current_chapter.title}",
generated_content=full_content[:500] if len(full_content) > 500 else full_content,
model="default"
)
db_session.add(history)
await db_session.commit()
db_committed = True
await db_session.refresh(current_chapter)
logger.info(f"成功创作章节 {chapter_id},共 {new_word_count}")
# 发送完成事件
yield f"data: {json.dumps({'type': 'done', 'message': '创作完成', 'word_count': new_word_count}, ensure_ascii=False)}\n\n"
break # 退出async for db_session循环
except GeneratorExit:
# SSE连接断开
logger.warning("章节生成器被提前关闭(SSE断开)")
if db_session and not db_committed:
try:
if db_session.in_transaction():
await db_session.rollback()
logger.info("章节生成事务已回滚(GeneratorExit")
except Exception as e:
logger.error(f"GeneratorExit回滚失败: {str(e)}")
except Exception as e:
logger.error(f"流式创作章节失败: {str(e)}")
if db_session and not db_committed:
try:
if db_session.in_transaction():
await db_session.rollback()
logger.info("章节生成事务已回滚(异常)")
except Exception as rollback_error:
logger.error(f"回滚失败: {str(rollback_error)}")
yield f"data: {json.dumps({'type': 'error', 'error': str(e)}, ensure_ascii=False)}\n\n"
finally:
# 确保数据库会话被正确关闭
if db_session:
try:
# 最后检查:确保没有未提交的事务
if not db_committed and db_session.in_transaction():
await db_session.rollback()
logger.warning("在finally中发现未提交事务,已回滚")
await db_session.close()
logger.info("数据库会话已关闭")
except Exception as close_error:
logger.error(f"关闭数据库会话失败: {str(close_error)}")
# 强制关闭
try:
await db_session.close()
except:
pass
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
+491
View File
@@ -0,0 +1,491 @@
"""角色管理API"""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
import json
from app.database import get_db
from app.models.character import Character
from app.models.project import Project
from app.models.generation_history import GenerationHistory
from app.models.relationship import CharacterRelationship, Organization, OrganizationMember, RelationshipType
from app.schemas.character import (
CharacterUpdate,
CharacterResponse,
CharacterListResponse,
CharacterGenerateRequest
)
from app.services.ai_service import ai_service
from app.services.prompt_service import prompt_service
from app.logger import get_logger
router = APIRouter(prefix="/characters", tags=["角色管理"])
logger = get_logger(__name__)
@router.get("", response_model=CharacterListResponse, summary="获取角色列表")
async def get_characters(
project_id: str,
db: AsyncSession = Depends(get_db)
):
"""获取指定项目的所有角色(query参数版本)"""
# 获取总数
count_result = await db.execute(
select(func.count(Character.id)).where(Character.project_id == project_id)
)
total = count_result.scalar_one()
# 获取角色列表
result = await db.execute(
select(Character)
.where(Character.project_id == project_id)
.order_by(Character.created_at.desc())
)
characters = result.scalars().all()
return CharacterListResponse(total=total, items=characters)
@router.get("/project/{project_id}", response_model=CharacterListResponse, summary="获取项目的所有角色")
async def get_project_characters(
project_id: str,
db: AsyncSession = Depends(get_db)
):
"""获取指定项目的所有角色(路径参数版本)"""
# 获取总数
count_result = await db.execute(
select(func.count(Character.id)).where(Character.project_id == project_id)
)
total = count_result.scalar_one()
# 获取角色列表
result = await db.execute(
select(Character)
.where(Character.project_id == project_id)
.order_by(Character.created_at.desc())
)
characters = result.scalars().all()
return CharacterListResponse(total=total, items=characters)
@router.get("/{character_id}", response_model=CharacterResponse, summary="获取角色详情")
async def get_character(
character_id: str,
db: AsyncSession = Depends(get_db)
):
"""根据ID获取角色详情"""
result = await db.execute(
select(Character).where(Character.id == character_id)
)
character = result.scalar_one_or_none()
if not character:
raise HTTPException(status_code=404, detail="角色不存在")
return character
@router.put("/{character_id}", response_model=CharacterResponse, summary="更新角色")
async def update_character(
character_id: str,
character_update: CharacterUpdate,
db: AsyncSession = Depends(get_db)
):
"""更新角色信息"""
result = await db.execute(
select(Character).where(Character.id == character_id)
)
character = result.scalar_one_or_none()
if not character:
raise HTTPException(status_code=404, detail="角色不存在")
# 更新字段
update_data = character_update.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(character, field, value)
await db.commit()
await db.refresh(character)
return character
@router.delete("/{character_id}", summary="删除角色")
async def delete_character(
character_id: str,
db: AsyncSession = Depends(get_db)
):
"""删除角色"""
result = await db.execute(
select(Character).where(Character.id == character_id)
)
character = result.scalar_one_or_none()
if not character:
raise HTTPException(status_code=404, detail="角色不存在")
await db.delete(character)
await db.commit()
return {"message": "角色删除成功"}
@router.post("/generate", response_model=CharacterResponse, summary="AI生成角色")
async def generate_character(
request: CharacterGenerateRequest,
db: AsyncSession = Depends(get_db)
):
"""
使用AI生成角色卡
根据用户输入的信息,结合项目的世界观、主题等背景,
AI会生成一个完整、详细的角色设定卡片。
生成内容包括:姓名、年龄、性别、性格、外貌、背景故事、人际关系等
"""
# 验证项目是否存在并获取项目信息
result = await db.execute(
select(Project).where(Project.id == request.project_id)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
try:
# 获取已存在的角色列表,用于关系网络
existing_chars_result = await db.execute(
select(Character)
.where(Character.project_id == request.project_id)
.order_by(Character.created_at.desc())
)
existing_characters = existing_chars_result.scalars().all()
# 构建现有角色信息摘要(包含组织)
existing_chars_info = ""
character_list = []
organization_list = []
if existing_characters:
for c in existing_characters[:10]: # 最多显示10个
if c.is_organization:
organization_list.append(f"- {c.name} [{c.organization_type or '组织'}]")
else:
character_list.append(f"- {c.name}{c.role_type or '未知'}")
if character_list:
existing_chars_info += "\n已有角色:\n" + "\n".join(character_list)
if organization_list:
existing_chars_info += "\n\n已有组织:\n" + "\n".join(organization_list)
# 构建项目上下文信息
project_context = f"""
项目信息:
- 书名:{project.title}
- 主题:{project.theme or '未设定'}
- 类型:{project.genre or '未设定'}
- 时间背景:{project.world_time_period or '未设定'}
- 地理位置:{project.world_location or '未设定'}
- 氛围基调:{project.world_atmosphere or '未设定'}
- 世界规则:{project.world_rules or '未设定'}
{existing_chars_info}
"""
# 构建用户输入信息
user_input = f"""
用户要求:
- 角色名称:{request.name or '请AI生成'}
- 角色定位:{request.role_type or 'supporting'}protagonist=主角, supporting=配角, antagonist=反派)
- 背景设定:{request.background or '无特殊要求'}
- 其他要求:{request.requirements or ''}
"""
# 使用统一的提示词服务
prompt = prompt_service.get_single_character_prompt(
project_context=project_context,
user_input=user_input
)
# 调用AI生成角色
logger.info(f"🎯 开始为项目 {request.project_id} 生成角色")
logger.info(f" - 角色名:{request.name or 'AI生成'}")
logger.info(f" - 角色定位:{request.role_type}")
logger.info(f" - 背景设定:{request.background or ''}")
logger.info(f" - AI提供商:{request.provider or 'default'}")
logger.info(f" - AI模型:{request.model or 'default'}")
logger.info(f" - Prompt长度:{len(prompt)} 字符")
try:
ai_response = await ai_service.generate_text(
prompt=prompt,
provider=request.provider,
model=request.model
)
logger.info(f"✅ AI响应接收完成,长度:{len(ai_response) if ai_response else 0} 字符")
except Exception as ai_error:
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
raise HTTPException(
status_code=500,
detail=f"AI服务调用失败:{str(ai_error)}"
)
# 检查AI响应
if not ai_response or not ai_response.strip():
logger.error("❌ AI返回了空响应")
raise HTTPException(
status_code=500,
detail="AI服务返回空响应。可能原因:1) API配置错误 2) 模型不支持 3) 网络问题。请检查后端日志。"
)
logger.info(f"📝 开始清理AI响应")
# 清理AI响应,移除可能的markdown标记
cleaned_response = ai_response.strip()
original_length = len(cleaned_response)
if cleaned_response.startswith("```json"):
cleaned_response = cleaned_response[7:]
logger.info(" - 移除了 ```json 标记")
if cleaned_response.startswith("```"):
cleaned_response = cleaned_response[3:]
logger.info(" - 移除了 ``` 标记")
if cleaned_response.endswith("```"):
cleaned_response = cleaned_response[:-3]
logger.info(" - 移除了末尾 ``` 标记")
cleaned_response = cleaned_response.strip()
logger.info(f" - 清理前长度:{original_length},清理后长度:{len(cleaned_response)}")
logger.info(f" - 清理后内容预览(前300字符):{cleaned_response[:300]}")
# 解析AI响应
logger.info(f"🔍 开始解析JSON")
try:
character_data = json.loads(cleaned_response)
logger.info(f"✅ JSON解析成功")
logger.info(f" - 解析后的字段:{list(character_data.keys())}")
except json.JSONDecodeError as e:
logger.error(f"❌ JSON解析失败")
logger.error(f" - 错误位置:line {e.lineno}, column {e.colno}")
logger.error(f" - 错误信息:{str(e)}")
logger.error(f" - 完整响应内容(前1000字符):{cleaned_response[:1000]}")
raise HTTPException(
status_code=500,
detail=f"AI返回的内容无法解析为JSON。错误:{str(e)}。响应内容已记录到日志,请查看后端日志排查。"
)
# 转换traits为JSON字符串
traits_json = json.dumps(character_data.get("traits", []), ensure_ascii=False) if character_data.get("traits") else None
# 判断是否为组织
is_organization = character_data.get("is_organization", False)
# 创建角色
character = Character(
project_id=request.project_id,
name=character_data.get("name", request.name or "未命名角色"),
age=str(character_data.get("age", "")),
gender=character_data.get("gender"),
is_organization=is_organization,
role_type=request.role_type or "supporting",
personality=character_data.get("personality", ""),
background=character_data.get("background", ""),
appearance=character_data.get("appearance", ""),
relationships=character_data.get("relationships_text", character_data.get("relationships", "")), # 优先使用文本描述
organization_type=character_data.get("organization_type") if is_organization else None,
organization_purpose=character_data.get("organization_purpose") if is_organization else None,
organization_members=json.dumps(character_data.get("organization_members", []), ensure_ascii=False) if is_organization else None,
traits=traits_json
)
db.add(character)
await db.flush() # 获取character.id
logger.info(f"✅ 角色创建成功:{character.name} (ID: {character.id}, 是否组织: {is_organization})")
# 如果是组织,自动创建Organization详情记录
if is_organization:
org_check = await db.execute(
select(Organization).where(Organization.character_id == character.id)
)
existing_org = org_check.scalar_one_or_none()
if not existing_org:
organization = Organization(
character_id=character.id,
project_id=request.project_id,
member_count=0,
power_level=character_data.get("power_level", 50),
location=character_data.get("location"),
motto=character_data.get("motto")
)
db.add(organization)
await db.flush()
logger.info(f"✅ 自动创建组织详情:{character.name} (Org ID: {organization.id})")
else:
logger.info(f"️ 组织详情已存在:{character.name}")
# 处理结构化关系数据(仅针对非组织角色)
if not is_organization:
relationships_data = character_data.get("relationships", [])
if relationships_data and isinstance(relationships_data, list):
logger.info(f"📊 开始处理 {len(relationships_data)} 条关系数据")
created_rels = 0
for rel in relationships_data:
try:
target_name = rel.get("target_character_name")
if not target_name:
logger.debug(f" ⚠️ 关系缺少target_character_name,跳过")
continue
target_result = await db.execute(
select(Character).where(
Character.project_id == request.project_id,
Character.name == target_name
)
)
target_char = target_result.scalar_one_or_none()
if target_char:
# 检查是否已存在相同关系
existing_rel = await db.execute(
select(CharacterRelationship).where(
CharacterRelationship.project_id == request.project_id,
CharacterRelationship.character_from_id == character.id,
CharacterRelationship.character_to_id == target_char.id
)
)
if existing_rel.scalar_one_or_none():
logger.debug(f" ️ 关系已存在:{character.name} -> {target_name}")
continue
relationship = CharacterRelationship(
project_id=request.project_id,
character_from_id=character.id,
character_to_id=target_char.id,
relationship_name=rel.get("relationship_type", "未知关系"),
intimacy_level=rel.get("intimacy_level", 50),
description=rel.get("description", ""),
started_at=rel.get("started_at"),
source="ai"
)
# 匹配预定义关系类型
rel_type_result = await db.execute(
select(RelationshipType).where(
RelationshipType.name == rel.get("relationship_type")
)
)
rel_type = rel_type_result.scalar_one_or_none()
if rel_type:
relationship.relationship_type_id = rel_type.id
db.add(relationship)
created_rels += 1
logger.info(f" ✅ 创建关系:{character.name} -> {target_name} ({rel.get('relationship_type')})")
else:
logger.warning(f" ⚠️ 目标角色不存在:{target_name}")
except Exception as rel_error:
logger.warning(f" ❌ 创建关系失败:{str(rel_error)}")
continue
logger.info(f"✅ 成功创建 {created_rels} 条关系记录")
# 处理组织成员关系(仅针对非组织角色)
if not is_organization:
org_memberships = character_data.get("organization_memberships", [])
if org_memberships and isinstance(org_memberships, list):
logger.info(f"🏢 开始处理 {len(org_memberships)} 条组织成员关系")
created_members = 0
for membership in org_memberships:
try:
org_name = membership.get("organization_name")
if not org_name:
logger.debug(f" ⚠️ 组织成员关系缺少organization_name,跳过")
continue
org_char_result = await db.execute(
select(Character).where(
Character.project_id == request.project_id,
Character.name == org_name,
Character.is_organization == True
)
)
org_char = org_char_result.scalar_one_or_none()
if org_char:
# 获取或创建Organization记录
org_result = await db.execute(
select(Organization).where(Organization.character_id == org_char.id)
)
org = org_result.scalar_one_or_none()
if not org:
# 如果组织Character存在但Organization不存在,自动创建
org = Organization(
character_id=org_char.id,
project_id=request.project_id,
member_count=0
)
db.add(org)
await db.flush()
logger.info(f" ️ 自动创建缺失的组织详情:{org_name}")
# 检查是否已存在成员关系
existing_member = await db.execute(
select(OrganizationMember).where(
OrganizationMember.organization_id == org.id,
OrganizationMember.character_id == character.id
)
)
if existing_member.scalar_one_or_none():
logger.debug(f" ️ 成员关系已存在:{character.name} -> {org_name}")
continue
# 创建成员关系
member = OrganizationMember(
organization_id=org.id,
character_id=character.id,
position=membership.get("position", "成员"),
rank=membership.get("rank", 0),
loyalty=membership.get("loyalty", 50),
joined_at=membership.get("joined_at"),
status=membership.get("status", "active"),
source="ai"
)
db.add(member)
# 更新组织成员计数
org.member_count += 1
created_members += 1
logger.info(f" ✅ 添加成员:{character.name} -> {org_name} ({membership.get('position')})")
else:
logger.warning(f" ⚠️ 组织不存在:{org_name}")
except Exception as org_error:
logger.warning(f" ❌ 添加组织成员失败:{str(org_error)}")
continue
logger.info(f"✅ 成功创建 {created_members} 条组织成员记录")
# 记录生成历史
history = GenerationHistory(
project_id=request.project_id,
prompt=prompt,
generated_content=ai_response,
model=request.model or "default"
)
db.add(history)
await db.commit()
await db.refresh(character)
logger.info(f"🎉 成功为项目 {request.project_id} 生成角色: {character.name}")
return character
except Exception as e:
logger.error(f"生成角色失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"生成角色失败: {str(e)}")
+341
View File
@@ -0,0 +1,341 @@
"""组织管理API"""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_
from typing import List
from app.database import get_db
from app.models.relationship import Organization, OrganizationMember
from app.models.character import Character
from app.schemas.relationship import (
OrganizationCreate,
OrganizationUpdate,
OrganizationResponse,
OrganizationDetailResponse,
OrganizationMemberCreate,
OrganizationMemberUpdate,
OrganizationMemberResponse,
OrganizationMemberDetailResponse
)
from app.logger import get_logger
router = APIRouter(prefix="/organizations", tags=["组织管理"])
logger = get_logger(__name__)
@router.get("/project/{project_id}", response_model=List[OrganizationDetailResponse], summary="获取项目的所有组织")
async def get_project_organizations(
project_id: str,
db: AsyncSession = Depends(get_db)
):
"""
获取项目中的所有组织及其详情
返回组织的基本信息和统计数据
"""
result = await db.execute(
select(Organization).where(Organization.project_id == project_id)
)
organizations = result.scalars().all()
# 获取每个组织的角色信息
org_list = []
for org in organizations:
char_result = await db.execute(
select(Character).where(Character.id == org.character_id)
)
char = char_result.scalar_one_or_none()
if char:
org_list.append(OrganizationDetailResponse(
id=org.id,
character_id=org.character_id,
name=char.name,
type=char.organization_type,
purpose=char.organization_purpose,
member_count=org.member_count,
power_level=org.power_level,
location=org.location,
motto=org.motto,
color=org.color
))
logger.info(f"获取项目 {project_id} 的组织列表,共 {len(org_list)}")
return org_list
@router.get("/{org_id}", response_model=OrganizationResponse, summary="获取组织详情")
async def get_organization(
org_id: str,
db: AsyncSession = Depends(get_db)
):
"""获取组织的详细信息"""
result = await db.execute(
select(Organization).where(Organization.id == org_id)
)
org = result.scalar_one_or_none()
if not org:
raise HTTPException(status_code=404, detail="组织不存在")
return org
@router.post("/", response_model=OrganizationResponse, summary="创建组织")
async def create_organization(
organization: OrganizationCreate,
db: AsyncSession = Depends(get_db)
):
"""
创建新组织
- 需要关联到一个已存在的角色记录(is_organization=True
- 可以设置父组织、势力等级等属性
"""
# 验证角色是否存在且是组织
char_result = await db.execute(
select(Character).where(Character.id == organization.character_id)
)
char = char_result.scalar_one_or_none()
if not char:
raise HTTPException(status_code=404, detail="关联的角色不存在")
if not char.is_organization:
raise HTTPException(status_code=400, detail="关联的角色不是组织类型")
# 检查是否已存在
existing = await db.execute(
select(Organization).where(Organization.character_id == organization.character_id)
)
if existing.scalar_one_or_none():
raise HTTPException(status_code=400, detail="该角色已有组织详情记录")
# 创建组织
db_org = Organization(**organization.model_dump())
db.add(db_org)
await db.commit()
await db.refresh(db_org)
logger.info(f"创建组织成功:{db_org.id} - {char.name}")
return db_org
@router.put("/{org_id}", response_model=OrganizationResponse, summary="更新组织")
async def update_organization(
org_id: str,
organization: OrganizationUpdate,
db: AsyncSession = Depends(get_db)
):
"""更新组织的属性"""
result = await db.execute(
select(Organization).where(Organization.id == org_id)
)
db_org = result.scalar_one_or_none()
if not db_org:
raise HTTPException(status_code=404, detail="组织不存在")
# 更新字段
update_data = organization.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(db_org, field, value)
await db.commit()
await db.refresh(db_org)
logger.info(f"更新组织成功:{org_id}")
return db_org
@router.delete("/{org_id}", summary="删除组织")
async def delete_organization(
org_id: str,
db: AsyncSession = Depends(get_db)
):
"""删除组织(会级联删除所有成员关系)"""
result = await db.execute(
select(Organization).where(Organization.id == org_id)
)
db_org = result.scalar_one_or_none()
if not db_org:
raise HTTPException(status_code=404, detail="组织不存在")
await db.delete(db_org)
await db.commit()
logger.info(f"删除组织成功:{org_id}")
return {"message": "组织删除成功", "id": org_id}
# ============ 组织成员管理 ============
@router.get("/{org_id}/members", response_model=List[OrganizationMemberDetailResponse], summary="获取组织成员")
async def get_organization_members(
org_id: str,
db: AsyncSession = Depends(get_db)
):
"""
获取组织的所有成员
按职位等级(rank)降序排列
"""
# 验证组织存在
org_result = await db.execute(
select(Organization).where(Organization.id == org_id)
)
if not org_result.scalar_one_or_none():
raise HTTPException(status_code=404, detail="组织不存在")
# 获取成员列表
result = await db.execute(
select(OrganizationMember)
.where(OrganizationMember.organization_id == org_id)
.order_by(OrganizationMember.rank.desc(), OrganizationMember.created_at)
)
members = result.scalars().all()
# 获取成员角色信息
member_list = []
for member in members:
char_result = await db.execute(
select(Character).where(Character.id == member.character_id)
)
char = char_result.scalar_one_or_none()
if char:
member_list.append(OrganizationMemberDetailResponse(
id=member.id,
character_id=member.character_id,
character_name=char.name,
position=member.position,
rank=member.rank,
loyalty=member.loyalty,
contribution=member.contribution,
status=member.status,
joined_at=member.joined_at,
left_at=member.left_at,
notes=member.notes
))
logger.info(f"获取组织 {org_id} 的成员列表,共 {len(member_list)}")
return member_list
@router.post("/{org_id}/members", response_model=OrganizationMemberResponse, summary="添加组织成员")
async def add_organization_member(
org_id: str,
member: OrganizationMemberCreate,
db: AsyncSession = Depends(get_db)
):
"""
添加角色到组织
- 一个角色在同一组织中只能有一个职位
- 会自动更新组织的成员计数
"""
# 验证组织存在
org_result = await db.execute(
select(Organization).where(Organization.id == org_id)
)
org = org_result.scalar_one_or_none()
if not org:
raise HTTPException(status_code=404, detail="组织不存在")
# 验证角色存在
char_result = await db.execute(
select(Character).where(Character.id == member.character_id)
)
char = char_result.scalar_one_or_none()
if not char:
raise HTTPException(status_code=404, detail="角色不存在")
if char.is_organization:
raise HTTPException(status_code=400, detail="不能将组织添加为成员")
# 检查是否已存在
existing = await db.execute(
select(OrganizationMember).where(
and_(
OrganizationMember.organization_id == org_id,
OrganizationMember.character_id == member.character_id
)
)
)
if existing.scalar_one_or_none():
raise HTTPException(status_code=400, detail="该角色已在组织中")
# 创建成员关系
db_member = OrganizationMember(
organization_id=org_id,
**member.model_dump(),
source="manual"
)
db.add(db_member)
# 更新组织成员计数
org.member_count += 1
await db.commit()
await db.refresh(db_member)
logger.info(f"添加成员成功:{char.name} 加入组织 {org_id}")
return db_member
@router.put("/members/{member_id}", response_model=OrganizationMemberResponse, summary="更新成员信息")
async def update_organization_member(
member_id: str,
member: OrganizationMemberUpdate,
db: AsyncSession = Depends(get_db)
):
"""更新组织成员的职位、忠诚度等信息"""
result = await db.execute(
select(OrganizationMember).where(OrganizationMember.id == member_id)
)
db_member = result.scalar_one_or_none()
if not db_member:
raise HTTPException(status_code=404, detail="成员记录不存在")
# 更新字段
update_data = member.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(db_member, field, value)
await db.commit()
await db.refresh(db_member)
logger.info(f"更新成员信息成功:{member_id}")
return db_member
@router.delete("/members/{member_id}", summary="移除组织成员")
async def remove_organization_member(
member_id: str,
db: AsyncSession = Depends(get_db)
):
"""
从组织中移除成员
会自动更新组织的成员计数
"""
result = await db.execute(
select(OrganizationMember).where(OrganizationMember.id == member_id)
)
db_member = result.scalar_one_or_none()
if not db_member:
raise HTTPException(status_code=404, detail="成员记录不存在")
# 更新组织成员计数
org_result = await db.execute(
select(Organization).where(Organization.id == db_member.organization_id)
)
org = org_result.scalar_one()
org.member_count = max(0, org.member_count - 1)
await db.delete(db_member)
await db.commit()
logger.info(f"移除成员成功:{member_id}")
return {"message": "成员移除成功", "id": member_id}
+657
View File
@@ -0,0 +1,657 @@
"""大纲管理API"""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, delete
from typing import List
import json
from app.database import get_db
from app.models.outline import Outline
from app.models.project import Project
from app.models.chapter import Chapter
from app.models.character import Character
from app.models.generation_history import GenerationHistory
from app.schemas.outline import (
OutlineCreate,
OutlineUpdate,
OutlineResponse,
OutlineListResponse,
OutlineGenerateRequest,
OutlineReorderRequest
)
from app.services.ai_service import ai_service
from app.services.prompt_service import prompt_service
from app.logger import get_logger
router = APIRouter(prefix="/outlines", tags=["大纲管理"])
logger = get_logger(__name__)
@router.post("", response_model=OutlineResponse, summary="创建大纲")
async def create_outline(
outline: OutlineCreate,
db: AsyncSession = Depends(get_db)
):
"""创建新的章节大纲,同时创建对应的章节记录"""
# 验证项目是否存在
result = await db.execute(
select(Project).where(Project.id == outline.project_id)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 创建大纲
db_outline = Outline(**outline.model_dump())
db.add(db_outline)
# 同步创建对应的章节记录
chapter = Chapter(
project_id=outline.project_id,
chapter_number=outline.order_index,
title=outline.title,
summary=outline.content[:500] if len(outline.content) > 500 else outline.content,
status="draft"
)
db.add(chapter)
await db.commit()
await db.refresh(db_outline)
return db_outline
@router.get("", response_model=OutlineListResponse, summary="获取大纲列表")
async def get_outlines(
project_id: str,
db: AsyncSession = Depends(get_db)
):
"""获取指定项目的所有大纲"""
# 获取总数
count_result = await db.execute(
select(func.count(Outline.id)).where(Outline.project_id == project_id)
)
total = count_result.scalar_one()
# 获取大纲列表
result = await db.execute(
select(Outline)
.where(Outline.project_id == project_id)
.order_by(Outline.order_index)
)
outlines = result.scalars().all()
return OutlineListResponse(total=total, items=outlines)
@router.get("/project/{project_id}", response_model=OutlineListResponse, summary="获取项目的所有大纲")
async def get_project_outlines(
project_id: str,
db: AsyncSession = Depends(get_db)
):
"""获取指定项目的所有大纲(路径参数版本)"""
# 获取总数
count_result = await db.execute(
select(func.count(Outline.id)).where(Outline.project_id == project_id)
)
total = count_result.scalar_one()
# 获取大纲列表
result = await db.execute(
select(Outline)
.where(Outline.project_id == project_id)
.order_by(Outline.order_index)
)
outlines = result.scalars().all()
return OutlineListResponse(total=total, items=outlines)
@router.get("/{outline_id}", response_model=OutlineResponse, summary="获取大纲详情")
async def get_outline(
outline_id: str,
db: AsyncSession = Depends(get_db)
):
"""根据ID获取大纲详情"""
result = await db.execute(
select(Outline).where(Outline.id == outline_id)
)
outline = result.scalar_one_or_none()
if not outline:
raise HTTPException(status_code=404, detail="大纲不存在")
return outline
@router.put("/{outline_id}", response_model=OutlineResponse, summary="更新大纲")
async def update_outline(
outline_id: str,
outline_update: OutlineUpdate,
db: AsyncSession = Depends(get_db)
):
"""更新大纲信息,同步更新对应章节和structure字段"""
result = await db.execute(
select(Outline).where(Outline.id == outline_id)
)
outline = result.scalar_one_or_none()
if not outline:
raise HTTPException(status_code=404, detail="大纲不存在")
# 更新字段
update_data = outline_update.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(outline, field, value)
# 如果修改了content或title,同步更新structure字段
if 'content' in update_data or 'title' in update_data:
try:
# 尝试解析现有的structure
if outline.structure:
structure_data = json.loads(outline.structure)
else:
structure_data = {}
# 更新structure中的对应字段
if 'title' in update_data:
structure_data['title'] = outline.title
if 'content' in update_data:
structure_data['summary'] = outline.content
structure_data['content'] = outline.content
# 保存更新后的structure
outline.structure = json.dumps(structure_data, ensure_ascii=False)
logger.info(f"同步更新大纲 {outline_id} 的structure字段")
except json.JSONDecodeError:
logger.warning(f"大纲 {outline_id} 的structure字段格式错误,跳过更新")
# 同步更新对应的章节标题和摘要
if 'title' in update_data or 'content' in update_data:
chapter_result = await db.execute(
select(Chapter).where(
Chapter.project_id == outline.project_id,
Chapter.chapter_number == outline.order_index
)
)
chapter = chapter_result.scalar_one_or_none()
if chapter:
if 'title' in update_data:
chapter.title = outline.title
if 'content' in update_data:
# 更新章节摘要(取content前500字符)
chapter.summary = outline.content[:500] if len(outline.content) > 500 else outline.content
logger.info(f"同步更新章节 {chapter.id} 的标题和摘要")
else:
logger.warning(f"未找到对应的章节记录 (order_index={outline.order_index})")
await db.commit()
await db.refresh(outline)
return outline
@router.delete("/{outline_id}", summary="删除大纲")
async def delete_outline(
outline_id: str,
db: AsyncSession = Depends(get_db)
):
"""删除大纲,同步删除章节,并重新排序后续项"""
result = await db.execute(
select(Outline).where(Outline.id == outline_id)
)
outline = result.scalar_one_or_none()
if not outline:
raise HTTPException(status_code=404, detail="大纲不存在")
project_id = outline.project_id
deleted_order = outline.order_index
# 删除对应的章节
await db.execute(
delete(Chapter).where(
Chapter.project_id == project_id,
Chapter.chapter_number == deleted_order
)
)
# 删除大纲
await db.delete(outline)
# 重新排序后续的大纲和章节(序号-1)
result = await db.execute(
select(Outline).where(
Outline.project_id == project_id,
Outline.order_index > deleted_order
)
)
subsequent_outlines = result.scalars().all()
for o in subsequent_outlines:
old_order = o.order_index
o.order_index -= 1
# 同步更新对应的章节
chapter_result = await db.execute(
select(Chapter).where(
Chapter.project_id == project_id,
Chapter.chapter_number == old_order
)
)
chapter = chapter_result.scalar_one_or_none()
if chapter:
chapter.chapter_number = old_order - 1
await db.commit()
return {"message": "大纲删除成功"}
@router.post("/reorder", summary="批量重排序大纲")
async def reorder_outlines(
request: OutlineReorderRequest,
db: AsyncSession = Depends(get_db)
):
"""
批量调整大纲顺序,同步更新章节序号
策略:先收集所有变更,最后一次性提交,避免临时冲突
"""
try:
# 第一步:收集所有大纲和对应的章节
outline_chapter_map = {} # {outline_id: (outline, chapter, old_order, new_order)}
for item in request.orders:
outline_id = item.id
new_order = item.order_index
# 获取大纲
result = await db.execute(
select(Outline).where(Outline.id == outline_id)
)
outline = result.scalar_one_or_none()
if not outline:
logger.warning(f"大纲 {outline_id} 不存在,跳过")
continue
old_order = outline.order_index
# 获取对应的章节(通过旧的chapter_number匹配)
chapter_result = await db.execute(
select(Chapter).where(
Chapter.project_id == outline.project_id,
Chapter.chapter_number == old_order
)
)
chapter = chapter_result.first()
chapter_obj = chapter[0] if chapter else None
outline_chapter_map[outline_id] = (outline, chapter_obj, old_order, new_order)
# 第二步:批量更新所有大纲和章节
updated_outlines = 0
updated_chapters = 0
for outline_id, (outline, chapter, old_order, new_order) in outline_chapter_map.items():
# 更新大纲
outline.order_index = new_order
updated_outlines += 1
# 更新章节
if chapter:
chapter.chapter_number = new_order
chapter.title = outline.title # 同步更新标题
updated_chapters += 1
else:
logger.warning(f"章节 {old_order} 不存在,跳过")
# 第三步:一次性提交所有更改
await db.commit()
logger.info(f"重排序成功:更新了 {updated_outlines} 个大纲,{updated_chapters} 个章节")
return {
"message": "重排序成功",
"updated_outlines": updated_outlines,
"updated_chapters": updated_chapters
}
except Exception as e:
await db.rollback()
logger.error(f"重排序失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"重排序失败: {str(e)}")
@router.post("/generate", response_model=OutlineListResponse, summary="AI生成/续写大纲")
async def generate_outline(
request: OutlineGenerateRequest,
db: AsyncSession = Depends(get_db)
):
"""
使用AI生成或续写小说大纲 - 智能模式
支持三种模式:
- auto: 自动判断(无大纲→新建,有大纲→续写)
- new: 强制全新生成
- continue: 强制续写模式
"""
# 验证项目是否存在
result = await db.execute(
select(Project).where(Project.id == request.project_id)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
try:
# 获取现有大纲(强制从数据库获取最新数据,包括用户手动修改的内容)
existing_result = await db.execute(
select(Outline)
.where(Outline.project_id == request.project_id)
.order_by(Outline.order_index)
.execution_options(populate_existing=True)
)
existing_outlines = existing_result.scalars().all()
# 判断实际执行模式
actual_mode = request.mode
if actual_mode == "auto":
actual_mode = "continue" if existing_outlines else "new"
logger.info(f"自动判断模式:{'续写' if existing_outlines else '新建'}")
# 模式:全新生成
if actual_mode == "new":
return await _generate_new_outline(
request, project, db
)
# 模式:续写
elif actual_mode == "continue":
if not existing_outlines:
raise HTTPException(
status_code=400,
detail="续写模式需要已有大纲,当前项目没有大纲"
)
return await _continue_outline(
request, project, existing_outlines, db
)
else:
raise HTTPException(
status_code=400,
detail=f"不支持的模式: {request.mode}"
)
except HTTPException:
raise
except Exception as e:
logger.error(f"生成大纲失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"生成大纲失败: {str(e)}")
async def _generate_new_outline(
request: OutlineGenerateRequest,
project: Project,
db: AsyncSession
) -> OutlineListResponse:
"""全新生成大纲"""
logger.info(f"全新生成大纲 - 项目: {project.id}, keep_existing: {request.keep_existing}")
# 获取角色信息
characters_result = await db.execute(
select(Character).where(Character.project_id == project.id)
)
characters = characters_result.scalars().all()
characters_info = "\n".join([
f"- {char.name} ({'组织' if char.is_organization else '角色'}, {char.role_type}): "
f"{char.personality[:100] if char.personality else '暂无描述'}"
for char in characters
])
# 使用完整提示词
prompt = prompt_service.get_complete_outline_prompt(
title=project.title,
theme=request.theme or project.theme or "未设定",
genre=request.genre or project.genre or "通用",
chapter_count=request.chapter_count,
narrative_perspective=request.narrative_perspective,
target_words=request.target_words,
time_period=project.world_time_period or "未设定",
location=project.world_location or "未设定",
atmosphere=project.world_atmosphere or "未设定",
rules=project.world_rules or "未设定",
characters_info=characters_info or "暂无角色信息",
requirements=request.requirements or ""
)
# 调用AI
ai_response = await ai_service.generate_text(
prompt=prompt,
provider=request.provider,
model=request.model
)
# 解析响应
outline_data = _parse_ai_response(ai_response)
# 全新生成模式:必须删除旧大纲和章节
# 注意:这是"new"模式的核心逻辑,应该始终删除旧数据
logger.info(f"删除项目 {project.id} 的旧大纲和章节")
await db.execute(
delete(Outline).where(Outline.project_id == project.id)
)
await db.execute(
delete(Chapter).where(Chapter.project_id == project.id)
)
# 保存新大纲
outlines = await _save_outlines(
project.id, outline_data, db, start_index=1
)
# 记录历史
history = GenerationHistory(
project_id=project.id,
prompt=prompt,
generated_content=ai_response,
model=request.model or "default"
)
db.add(history)
await db.commit()
for outline in outlines:
await db.refresh(outline)
logger.info(f"全新生成完成 - {len(outlines)}")
return OutlineListResponse(total=len(outlines), items=outlines)
async def _continue_outline(
request: OutlineGenerateRequest,
project: Project,
existing_outlines: List[Outline],
db: AsyncSession
) -> OutlineListResponse:
"""续写大纲"""
logger.info(f"续写大纲 - 项目: {project.id}, 已有: {len(existing_outlines)}")
# 分析已有大纲
current_chapter_count = len(existing_outlines)
last_chapter_number = existing_outlines[-1].order_index
# 获取最近2章的剧情
recent_outlines = existing_outlines[-2:] if len(existing_outlines) >= 2 else existing_outlines
recent_plot = "\n".join([
f"{o.order_index}章《{o.title}》: {o.content}"
for o in recent_outlines
])
# logger.debug(f"最近三章内容:{recent_plot}")
# 全部章节概览
all_chapters_brief = "\n".join([
f"{o.order_index}章: {o.title}"
for o in existing_outlines
])
# 获取角色信息
characters_result = await db.execute(
select(Character).where(Character.project_id == project.id)
)
characters = characters_result.scalars().all()
characters_info = "\n".join([
f"- {char.name} ({'组织' if char.is_organization else '角色'}, {char.role_type}): "
f"{char.personality[:100] if char.personality else '暂无描述'}"
for char in characters
])
# 情节阶段指导
stage_instructions = {
"development": "继续展开情节,深化角色关系,推进主线冲突",
"climax": "进入故事高潮,矛盾激化,关键冲突爆发",
"ending": "解决主要冲突,收束伏笔,给出结局"
}
stage_instruction = stage_instructions.get(request.plot_stage, "")
# 使用标准续写提示词模板
prompt = prompt_service.get_outline_continue_prompt(
title=project.title,
theme=request.theme or project.theme or "未设定",
genre=request.genre or project.genre or "通用",
narrative_perspective=request.narrative_perspective,
chapter_count=request.chapter_count,
time_period=project.world_time_period or "未设定",
location=project.world_location or "未设定",
atmosphere=project.world_atmosphere or "未设定",
rules=project.world_rules or "未设定",
characters_info=characters_info or "暂无角色信息",
current_chapter_count=current_chapter_count,
all_chapters_brief=all_chapters_brief,
recent_plot=recent_plot,
plot_stage_instruction=stage_instruction,
start_chapter=last_chapter_number + 1,
story_direction=request.story_direction or "自然延续",
requirements=request.requirements or ""
)
# 调用AI
ai_response = await ai_service.generate_text(
prompt=prompt,
provider=request.provider,
model=request.model
)
# 解析响应
outline_data = _parse_ai_response(ai_response)
# 保存续写的大纲
new_outlines = await _save_outlines(
project.id, outline_data, db, start_index=last_chapter_number + 1
)
# 记录历史
history = GenerationHistory(
project_id=project.id,
prompt=prompt,
generated_content=ai_response,
model=request.model or "default"
)
db.add(history)
await db.commit()
for outline in new_outlines:
await db.refresh(outline)
# 返回所有大纲(包括旧的和新的)
all_result = await db.execute(
select(Outline)
.where(Outline.project_id == project.id)
.order_by(Outline.order_index)
)
all_outlines = all_result.scalars().all()
logger.info(f"续写完成 - 新增 {len(new_outlines)} 章,总计 {len(all_outlines)}")
return OutlineListResponse(total=len(all_outlines), items=all_outlines)
def _parse_ai_response(ai_response: str) -> list:
"""解析AI响应为章节数据列表"""
try:
# 清理响应文本
cleaned_text = ai_response.strip()
if cleaned_text.startswith('```json'):
cleaned_text = cleaned_text[7:]
if cleaned_text.startswith('```'):
cleaned_text = cleaned_text[3:]
if cleaned_text.endswith('```'):
cleaned_text = cleaned_text[:-3]
cleaned_text = cleaned_text.strip()
outline_data = json.loads(cleaned_text)
# 确保是列表格式
if not isinstance(outline_data, list):
# 如果是对象,尝试提取chapters字段
if isinstance(outline_data, dict):
outline_data = outline_data.get("chapters", [outline_data])
else:
outline_data = [outline_data]
return outline_data
except json.JSONDecodeError as e:
logger.error(f"AI响应解析失败: {e}")
# 返回一个包含原始内容的章节
return [{
"title": "AI生成的大纲",
"content": ai_response[:1000],
"summary": ai_response[:1000]
}]
async def _save_outlines(
project_id: str,
outline_data: list,
db: AsyncSession,
start_index: int = 1
) -> List[Outline]:
"""保存大纲到数据库"""
outlines = []
for idx, chapter_data in enumerate(outline_data):
order_idx = chapter_data.get("chapter_number", start_index + idx)
title = chapter_data.get("title", f"{order_idx}")
# 优先使用summary,其次content
content = chapter_data.get("summary") or chapter_data.get("content", "")
# 如果有额外信息,添加到内容中
if "key_events" in chapter_data:
content += f"\n\n关键事件:" + "".join(chapter_data["key_events"])
if "characters_involved" in chapter_data:
content += f"\n涉及角色:" + "".join(chapter_data["characters_involved"])
# 创建大纲
outline = Outline(
project_id=project_id,
title=title,
content=content,
structure=json.dumps(chapter_data, ensure_ascii=False),
order_index=order_idx
)
db.add(outline)
outlines.append(outline)
# 同步创建章节记录
chapter = Chapter(
project_id=project_id,
chapter_number=order_idx,
title=title,
summary=content[:500] if len(content) > 500 else content,
status="draft"
)
db.add(chapter)
return outlines
+124
View File
@@ -0,0 +1,124 @@
"""AI去味API - 核心特色功能"""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.models.generation_history import GenerationHistory
from app.schemas.polish import PolishRequest, PolishResponse
from app.services.ai_service import ai_service
from app.services.prompt_service import prompt_service
from app.logger import get_logger
router = APIRouter(prefix="/polish", tags=["AI去味"])
logger = get_logger(__name__)
@router.post("", response_model=PolishResponse, summary="AI去味")
async def polish_text(
request: PolishRequest,
db: AsyncSession = Depends(get_db)
):
"""
AI去味 - 将AI生成的文本改写得更像人类作家的手笔
核心功能:
- 去除AI痕迹(工整排比、重复修辞、机械总结)
- 增加人性化(口语化、不完美细节、真实情感)
- 优化叙事(自然节奏、简单词汇、松弛感)
- 让对话更生活化
这是本项目的核心特色功能!
"""
try:
# 构建AI去味提示词
prompt = prompt_service.get_denoising_prompt(
original_text=request.original_text
)
logger.info(f"开始AI去味处理,原文长度: {len(request.original_text)}")
# 调用AI进行去味处理
polished_text = await ai_service.generate_text(
prompt=prompt,
provider=request.provider,
model=request.model,
temperature=request.temperature,
max_tokens=len(request.original_text) * 2 # 预留足够token
)
# 计算字数
word_count_before = len(request.original_text)
word_count_after = len(polished_text)
logger.info(f"AI去味完成,处理后长度: {word_count_after}")
# 如果提供了项目ID,记录到历史
if request.project_id:
history = GenerationHistory(
project_id=request.project_id,
generation_type="polish",
prompt=f"原文: {request.original_text[:100]}...",
result=polished_text,
provider=request.provider or "default",
model=request.model or "default"
)
db.add(history)
await db.commit()
return PolishResponse(
original_text=request.original_text,
polished_text=polished_text,
word_count_before=word_count_before,
word_count_after=word_count_after
)
except Exception as e:
logger.error(f"AI去味失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"AI去味失败: {str(e)}")
@router.post("/batch", summary="批量AI去味")
async def polish_batch(
texts: list[str],
project_id: int = None,
provider: str = None,
model: str = None,
db: AsyncSession = Depends(get_db)
):
"""
批量处理多个文本的AI去味
适用于一次性处理多个章节或段落
"""
try:
results = []
for idx, text in enumerate(texts):
logger.info(f"处理第 {idx+1}/{len(texts)} 个文本")
prompt = prompt_service.get_denoising_prompt(original_text=text)
polished_text = await ai_service.generate_text(
prompt=prompt,
provider=provider,
model=model
)
results.append({
"index": idx,
"original": text,
"polished": polished_text,
"word_count_before": len(text),
"word_count_after": len(polished_text)
})
logger.info(f"批量AI去味完成,共处理 {len(results)} 个文本")
return {
"total": len(results),
"results": results
}
except Exception as e:
logger.error(f"批量AI去味失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"批量AI去味失败: {str(e)}")
+414
View File
@@ -0,0 +1,414 @@
"""项目管理API"""
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import Response
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, delete
from typing import List
from app.database import get_db
from app.models.project import Project
from app.models.character import Character
from app.models.outline import Outline
from app.models.chapter import Chapter
from app.models.generation_history import GenerationHistory
from app.models.relationship import CharacterRelationship, Organization, OrganizationMember
from app.schemas.project import (
ProjectCreate,
ProjectUpdate,
ProjectResponse,
ProjectListResponse
)
from app.logger import get_logger
from app.utils.data_consistency import (
run_full_data_consistency_check,
fix_missing_organization_records,
fix_organization_member_counts
)
logger = get_logger(__name__)
router = APIRouter(prefix="/projects", tags=["项目管理"])
@router.post("", response_model=ProjectResponse, summary="创建项目")
async def create_project(
project: ProjectCreate,
db: AsyncSession = Depends(get_db)
):
try:
logger.info(f"创建新项目: {project.title}")
db_project = Project(**project.model_dump())
db.add(db_project)
await db.commit()
await db.refresh(db_project)
logger.info(f"项目创建成功: {db_project.id}")
return db_project
except Exception as e:
logger.error(f"创建项目失败: {str(e)}", exc_info=True)
raise
@router.get("", response_model=ProjectListResponse, summary="获取项目列表")
async def get_projects(
skip: int = 0,
limit: int = 100,
db: AsyncSession = Depends(get_db)
):
"""获取所有项目列表"""
try:
logger.debug(f"获取项目列表: skip={skip}, limit={limit}")
count_result = await db.execute(select(func.count(Project.id)))
total = count_result.scalar_one()
result = await db.execute(
select(Project)
.order_by(Project.updated_at.desc())
.offset(skip)
.limit(limit)
)
projects = result.scalars().all()
logger.info(f"获取项目列表成功: 共{total}个项目")
return ProjectListResponse(total=total, items=projects)
except Exception as e:
logger.error(f"获取项目列表失败: {str(e)}", exc_info=True)
raise
@router.get("/{project_id}", response_model=ProjectResponse, summary="获取项目详情")
async def get_project(
project_id: str,
db: AsyncSession = Depends(get_db)
):
try:
logger.debug(f"获取项目详情: {project_id}")
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目不存在: {project_id}")
raise HTTPException(status_code=404, detail="项目不存在")
logger.info(f"获取项目详情成功: {project.title}")
return project
except HTTPException:
raise
except Exception as e:
logger.error(f"获取项目详情失败: {str(e)}", exc_info=True)
raise
@router.put("/{project_id}", response_model=ProjectResponse, summary="更新项目")
async def update_project(
project_id: str,
project_update: ProjectUpdate,
db: AsyncSession = Depends(get_db)
):
try:
logger.info(f"更新项目: {project_id}")
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目不存在: {project_id}")
raise HTTPException(status_code=404, detail="项目不存在")
update_data = project_update.model_dump(exclude_unset=True)
logger.debug(f"更新字段: {list(update_data.keys())}")
for field, value in update_data.items():
setattr(project, field, value)
await db.commit()
await db.refresh(project)
logger.info(f"项目更新成功: {project.title}")
return project
except HTTPException:
raise
except Exception as e:
logger.error(f"更新项目失败: {str(e)}", exc_info=True)
raise
@router.delete("/{project_id}", summary="删除项目")
async def delete_project(
project_id: str,
db: AsyncSession = Depends(get_db)
):
try:
logger.info(f"删除项目: {project_id}")
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目不存在: {project_id}")
raise HTTPException(status_code=404, detail="项目不存在")
project_title = project.title
relationships_result = await db.execute(
delete(CharacterRelationship).where(CharacterRelationship.project_id == project_id)
)
logger.debug(f"删除角色关系数: {relationships_result.rowcount}")
orgs_result = await db.execute(
select(Organization).where(Organization.project_id == project_id)
)
orgs = orgs_result.scalars().all()
org_member_count = 0
for org in orgs:
members_result = await db.execute(
delete(OrganizationMember).where(OrganizationMember.organization_id == org.id)
)
org_member_count += members_result.rowcount
logger.debug(f"删除组织成员数: {org_member_count}")
organizations_result = await db.execute(
delete(Organization).where(Organization.project_id == project_id)
)
logger.debug(f"删除组织数: {organizations_result.rowcount}")
history_result = await db.execute(
delete(GenerationHistory).where(GenerationHistory.project_id == project_id)
)
logger.debug(f"删除生成历史数: {history_result.rowcount}")
chapters_result = await db.execute(
delete(Chapter).where(Chapter.project_id == project_id)
)
logger.debug(f"删除章节数: {chapters_result.rowcount}")
outlines_result = await db.execute(
delete(Outline).where(Outline.project_id == project_id)
)
logger.debug(f"删除大纲数: {outlines_result.rowcount}")
characters_result = await db.execute(
delete(Character).where(Character.project_id == project_id)
)
logger.debug(f"删除角色数: {characters_result.rowcount}")
await db.delete(project)
await db.commit()
logger.info(f"项目删除成功: {project_title}")
return {"message": "项目及所有关联数据删除成功"}
except HTTPException:
raise
except Exception as e:
logger.error(f"删除项目失败: {str(e)}", exc_info=True)
raise
@router.get("/{project_id}/export", summary="导出项目章节为TXT")
async def export_project_chapters(
project_id: str,
db: AsyncSession = Depends(get_db)
):
"""
导出项目的所有章节内容为TXT文本文件
按章节顺序组织,包含项目基本信息
"""
try:
logger.info(f"开始导出项目: {project_id}")
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目不存在: {project_id}")
raise HTTPException(status_code=404, detail="项目不存在")
chapters_result = await db.execute(
select(Chapter)
.where(Chapter.project_id == project_id)
.order_by(Chapter.chapter_number)
)
chapters = chapters_result.scalars().all()
if not chapters:
logger.warning(f"项目没有章节: {project_id}")
raise HTTPException(status_code=404, detail="项目没有任何章节")
txt_content = []
txt_content.append("=" * 80)
txt_content.append(f"项目标题: {project.title}")
txt_content.append("=" * 80)
if project.description:
txt_content.append(f"\n简介: {project.description}\n")
if project.theme:
txt_content.append(f"主题: {project.theme}")
if project.genre:
txt_content.append(f"类型: {project.genre}")
txt_content.append(f"总章节数: {len(chapters)}")
txt_content.append(f"总字数: {project.current_words}")
txt_content.append("\n" + "=" * 80 + "\n\n")
for chapter in chapters:
txt_content.append(f"{chapter.chapter_number}{chapter.title}")
txt_content.append("-" * 80)
txt_content.append("") # 空行
if chapter.content:
txt_content.append(chapter.content)
else:
txt_content.append("(本章暂无内容)")
txt_content.append("\n\n" + "=" * 80 + "\n\n")
txt_content.append(f"--- 全文完 ---")
txt_content.append(f"\n导出时间: {func.now()}")
final_content = "\n".join(txt_content)
safe_title = "".join(c for c in project.title if c.isalnum() or c in (' ', '-', '_', '', '', ''))
filename = f"{safe_title}.txt"
from urllib.parse import quote
encoded_filename = quote(filename)
logger.info(f"导出成功: {filename}, 共{len(chapters)}章, {len(final_content)}字符")
return Response(
content=final_content.encode('utf-8'),
media_type="text/plain; charset=utf-8",
headers={
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
"Content-Type": "text/plain; charset=utf-8"
}
)
except HTTPException:
raise
except Exception as e:
logger.error(f"导出项目失败: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"导出失败: {str(e)}")
@router.post("/{project_id}/check-consistency", summary="检查数据一致性")
async def check_project_consistency(
project_id: str,
auto_fix: bool = True,
db: AsyncSession = Depends(get_db)
):
"""
检查并修复项目的数据一致性问题
Args:
project_id: 项目ID
auto_fix: 是否自动修复问题(默认True)
返回检查报告,包含:
- organization_records: 检查并修复缺失的Organization记录
- member_counts: 检查并修复组织成员计数
- relationships: 验证关系数据完整性
- organization_members: 验证组织成员数据完整性
"""
try:
logger.info(f"开始数据一致性检查: {project_id}, auto_fix={auto_fix}")
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目不存在: {project_id}")
raise HTTPException(status_code=404, detail="项目不存在")
report = await run_full_data_consistency_check(project_id, db, auto_fix)
logger.info(f"数据一致性检查完成: {project_id}")
return report
except HTTPException:
raise
except Exception as e:
logger.error(f"数据一致性检查失败: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"检查失败: {str(e)}")
@router.post("/{project_id}/fix-organizations", summary="修复组织记录")
async def fix_project_organizations(
project_id: str,
db: AsyncSession = Depends(get_db)
):
"""
修复项目中缺失的Organization记录
为所有is_organization=True但没有Organization记录的Character创建记录
"""
try:
logger.info(f"开始修复组织记录: {project_id}")
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目不存在: {project_id}")
raise HTTPException(status_code=404, detail="项目不存在")
fixed_count, total_count = await fix_missing_organization_records(project_id, db)
logger.info(f"组织记录修复完成: {project_id}, 修复{fixed_count}/{total_count}")
return {
"message": "组织记录修复完成",
"fixed": fixed_count,
"total": total_count
}
except HTTPException:
raise
except Exception as e:
logger.error(f"修复组织记录失败: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"修复失败: {str(e)}")
@router.post("/{project_id}/fix-member-counts", summary="修复成员计数")
async def fix_project_member_counts(
project_id: str,
db: AsyncSession = Depends(get_db)
):
"""
修复项目中所有组织的成员计数
从实际成员记录重新计算每个组织的member_count
"""
try:
logger.info(f"开始修复成员计数: {project_id}")
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目不存在: {project_id}")
raise HTTPException(status_code=404, detail="项目不存在")
fixed_count, total_count = await fix_organization_member_counts(project_id, db)
logger.info(f"成员计数修复完成: {project_id}, 修复{fixed_count}/{total_count}")
return {
"message": "成员计数修复完成",
"fixed": fixed_count,
"total": total_count
}
except HTTPException:
raise
except Exception as e:
logger.error(f"修复成员计数失败: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"修复失败: {str(e)}")
+209
View File
@@ -0,0 +1,209 @@
"""关系管理API"""
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, or_, and_
from typing import List, Optional
from app.database import get_db
from app.models.relationship import (
RelationshipType,
CharacterRelationship,
Organization,
OrganizationMember
)
from app.models.character import Character
from app.schemas.relationship import (
RelationshipTypeResponse,
CharacterRelationshipCreate,
CharacterRelationshipUpdate,
CharacterRelationshipResponse,
RelationshipGraphData,
RelationshipGraphNode,
RelationshipGraphLink
)
from app.logger import get_logger
router = APIRouter(prefix="/relationships", tags=["关系管理"])
logger = get_logger(__name__)
@router.get("/types", response_model=List[RelationshipTypeResponse], summary="获取关系类型列表")
async def get_relationship_types(db: AsyncSession = Depends(get_db)):
"""获取所有预定义的关系类型"""
result = await db.execute(select(RelationshipType).order_by(RelationshipType.category, RelationshipType.id))
types = result.scalars().all()
return types
@router.get("/project/{project_id}", response_model=List[CharacterRelationshipResponse], summary="获取项目的所有关系")
async def get_project_relationships(
project_id: str,
character_id: Optional[str] = Query(None, description="筛选特定角色的关系"),
db: AsyncSession = Depends(get_db)
):
"""
获取项目中的所有角色关系
- 如果提供character_id,则只返回与该角色相关的关系(作为发起方或接收方)
- 否则返回项目中的所有关系
"""
query = select(CharacterRelationship).where(
CharacterRelationship.project_id == project_id
)
if character_id:
query = query.where(
or_(
CharacterRelationship.character_from_id == character_id,
CharacterRelationship.character_to_id == character_id
)
)
query = query.order_by(CharacterRelationship.created_at.desc())
result = await db.execute(query)
relationships = result.scalars().all()
logger.info(f"获取项目 {project_id} 的关系列表,共 {len(relationships)}")
return relationships
@router.get("/graph/{project_id}", response_model=RelationshipGraphData, summary="获取关系图谱数据")
async def get_relationship_graph(
project_id: str,
db: AsyncSession = Depends(get_db)
):
"""
获取用于可视化的关系图谱数据
返回格式:
- nodes: 角色节点列表
- links: 关系连线列表
"""
# 获取所有角色(节点)
chars_result = await db.execute(
select(Character).where(Character.project_id == project_id)
)
characters = chars_result.scalars().all()
nodes = [
RelationshipGraphNode(
id=c.id,
name=c.name,
type="organization" if c.is_organization else "character",
role_type=c.role_type,
avatar=c.avatar_url
)
for c in characters
]
# 获取所有关系(边)
rels_result = await db.execute(
select(CharacterRelationship).where(
CharacterRelationship.project_id == project_id
)
)
relationships = rels_result.scalars().all()
links = [
RelationshipGraphLink(
source=r.character_from_id,
target=r.character_to_id,
relationship=r.relationship_name or "未知关系",
intimacy=r.intimacy_level,
status=r.status
)
for r in relationships
]
logger.info(f"获取项目 {project_id} 的关系图谱:{len(nodes)} 个节点,{len(links)} 条关系")
return RelationshipGraphData(nodes=nodes, links=links)
@router.post("/", response_model=CharacterRelationshipResponse, summary="创建角色关系")
async def create_relationship(
relationship: CharacterRelationshipCreate,
db: AsyncSession = Depends(get_db)
):
"""
手动创建角色关系
- 需要提供角色A和角色B的ID
- 可以指定预定义的关系类型或自定义关系名称
- 可以设置亲密度、状态等属性
"""
# 验证角色是否存在
char_from = await db.execute(
select(Character).where(Character.id == relationship.character_from_id)
)
char_to = await db.execute(
select(Character).where(Character.id == relationship.character_to_id)
)
if not char_from.scalar_one_or_none():
raise HTTPException(status_code=404, detail=f"角色AID: {relationship.character_from_id})不存在")
if not char_to.scalar_one_or_none():
raise HTTPException(status_code=404, detail=f"角色BID: {relationship.character_to_id})不存在")
# 创建关系
db_relationship = CharacterRelationship(
**relationship.model_dump(),
source="manual"
)
db.add(db_relationship)
await db.commit()
await db.refresh(db_relationship)
logger.info(f"创建关系成功:{relationship.character_from_id} -> {relationship.character_to_id}")
return db_relationship
@router.put("/{relationship_id}", response_model=CharacterRelationshipResponse, summary="更新关系")
async def update_relationship(
relationship_id: str,
relationship: CharacterRelationshipUpdate,
db: AsyncSession = Depends(get_db)
):
"""更新角色关系的属性(亲密度、状态等)"""
result = await db.execute(
select(CharacterRelationship).where(
CharacterRelationship.id == relationship_id
)
)
db_rel = result.scalar_one_or_none()
if not db_rel:
raise HTTPException(status_code=404, detail="关系不存在")
# 更新字段
update_data = relationship.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(db_rel, field, value)
await db.commit()
await db.refresh(db_rel)
logger.info(f"更新关系成功:{relationship_id}")
return db_rel
@router.delete("/{relationship_id}", summary="删除关系")
async def delete_relationship(
relationship_id: str,
db: AsyncSession = Depends(get_db)
):
"""删除角色关系"""
result = await db.execute(
select(CharacterRelationship).where(
CharacterRelationship.id == relationship_id
)
)
db_rel = result.scalar_one_or_none()
if not db_rel:
raise HTTPException(status_code=404, detail="关系不存在")
await db.delete(db_rel)
await db.commit()
logger.info(f"删除关系成功:{relationship_id}")
return {"message": "关系删除成功", "id": relationship_id}
+125
View File
@@ -0,0 +1,125 @@
"""
用户管理 API
"""
from fastapi import APIRouter, HTTPException, Request, Depends
from pydantic import BaseModel
from typing import List, Optional
from app.user_manager import user_manager, User
router = APIRouter(prefix="/users", tags=["用户管理"])
def require_login(request: Request):
"""依赖:要求用户已登录"""
if not hasattr(request.state, "user") or not request.state.user:
raise HTTPException(status_code=401, detail="需要登录")
return request.state.user
def require_admin(request: Request):
"""依赖:要求用户为管理员"""
user = require_login(request)
if not request.state.is_admin:
raise HTTPException(status_code=403, detail="需要管理员权限")
return user
class SetAdminRequest(BaseModel):
user_id: str
is_admin: bool
@router.get("/current")
async def get_current_user(user: User = Depends(require_login)):
"""获取当前登录用户信息"""
return user.dict()
@router.get("", response_model=List[dict])
async def list_users(admin_user: User = Depends(require_admin)):
"""
获取所有用户列表(仅管理员)
"""
users = await user_manager.get_all_users()
return [user.dict() for user in users]
@router.post("/set-admin")
async def set_admin(
data: SetAdminRequest,
request: Request,
admin_user: User = Depends(require_admin)
):
"""
设置用户的管理员权限(仅管理员)
限制:
- 不能撤销自己的管理员权限
- 至少保留一个管理员
"""
# 检查是否尝试撤销自己的权限
if data.user_id == admin_user.user_id and not data.is_admin:
raise HTTPException(
status_code=400,
detail="不能撤销自己的管理员权限"
)
# 尝试设置管理员权限
success = await user_manager.set_admin(data.user_id, data.is_admin)
if not success:
if not data.is_admin:
raise HTTPException(
status_code=400,
detail="无法撤销管理员权限,至少需要保留一个管理员"
)
else:
raise HTTPException(
status_code=404,
detail="用户不存在"
)
return {
"message": f"{'授予' if data.is_admin else '撤销'}管理员权限",
"user_id": data.user_id,
"is_admin": data.is_admin
}
@router.delete("/{user_id}")
async def delete_user(
user_id: str,
admin_user: User = Depends(require_admin)
):
"""
删除用户(仅管理员)
限制:
- 不能删除管理员用户
"""
success = await user_manager.delete_user(user_id)
if not success:
raise HTTPException(
status_code=400,
detail="无法删除该用户(用户不存在或为管理员)"
)
return {
"message": "用户已删除",
"user_id": user_id
}
@router.get("/{user_id}")
async def get_user(
user_id: str,
admin_user: User = Depends(require_admin)
):
"""获取指定用户信息(仅管理员)"""
user = await user_manager.get_user(user_id)
if not user:
raise HTTPException(status_code=404, detail="用户不存在")
return user.dict()
File diff suppressed because it is too large Load Diff