update:1.更新根据分析建议重新生成章节内容
This commit is contained in:
+212
-18
@@ -9,6 +9,7 @@ import hashlib
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from app.services.oauth_service import LinuxDOOAuthService
|
||||
from app.user_manager import user_manager
|
||||
from app.user_password import password_manager
|
||||
from app.database import init_db
|
||||
from app.logger import get_logger
|
||||
from app.config import settings
|
||||
@@ -49,6 +50,25 @@ class LocalLoginResponse(BaseModel):
|
||||
user: Optional[dict] = None
|
||||
|
||||
|
||||
class SetPasswordRequest(BaseModel):
|
||||
"""设置密码请求"""
|
||||
password: str
|
||||
|
||||
|
||||
class SetPasswordResponse(BaseModel):
|
||||
"""设置密码响应"""
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class PasswordStatusResponse(BaseModel):
|
||||
"""密码状态响应"""
|
||||
has_password: bool
|
||||
has_custom_password: bool
|
||||
username: Optional[str] = None
|
||||
default_password: Optional[str] = None
|
||||
|
||||
|
||||
@router.get("/config")
|
||||
async def get_auth_config():
|
||||
"""获取认证配置信息"""
|
||||
@@ -60,30 +80,77 @@ async def get_auth_config():
|
||||
|
||||
@router.post("/local/login", response_model=LocalLoginResponse)
|
||||
async def local_login(request: LocalLoginRequest, response: Response):
|
||||
"""本地账户登录"""
|
||||
"""本地账户登录(支持.env配置的管理员账号和Linux DO授权后绑定的账号)"""
|
||||
# 检查是否启用本地登录
|
||||
if not settings.LOCAL_AUTH_ENABLED:
|
||||
raise HTTPException(status_code=403, detail="本地账户登录未启用")
|
||||
|
||||
# 检查是否配置了本地账户
|
||||
if not settings.LOCAL_AUTH_USERNAME or not settings.LOCAL_AUTH_PASSWORD:
|
||||
raise HTTPException(status_code=500, detail="本地账户未配置")
|
||||
logger.info(f"[本地登录] 尝试登录用户名: {request.username}")
|
||||
|
||||
# 验证用户名和密码
|
||||
if request.username != settings.LOCAL_AUTH_USERNAME or request.password != settings.LOCAL_AUTH_PASSWORD:
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
# 首先尝试查找 Linux DO 授权后绑定的账号
|
||||
all_users = await user_manager.get_all_users()
|
||||
target_user = None
|
||||
|
||||
# 生成本地用户ID(使用用户名的hash)
|
||||
user_id = f"local_{hashlib.md5(request.username.encode()).hexdigest()[:16]}"
|
||||
for user in all_users:
|
||||
# 同时检查 users 表的 username 和 user_passwords 表的 username
|
||||
password_username = await password_manager.get_username(user.user_id)
|
||||
if user.username == request.username or password_username == request.username:
|
||||
target_user = user
|
||||
logger.info(f"[本地登录] 找到 Linux DO 授权用户: {user.user_id}")
|
||||
break
|
||||
|
||||
# 创建或更新本地用户
|
||||
user = await user_manager.create_or_update_from_linuxdo(
|
||||
linuxdo_id=user_id,
|
||||
username=request.username,
|
||||
display_name=settings.LOCAL_AUTH_DISPLAY_NAME,
|
||||
avatar_url=None,
|
||||
trust_level=9 # 本地用户给予高信任级别
|
||||
)
|
||||
# 如果找到了 Linux DO 授权的用户
|
||||
if target_user:
|
||||
# 检查是否有密码
|
||||
if not await password_manager.has_password(target_user.user_id):
|
||||
logger.warning(f"[本地登录] 用户 {target_user.user_id} 没有设置密码")
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
# 验证密码
|
||||
if not await password_manager.verify_password(target_user.user_id, request.password):
|
||||
logger.warning(f"[本地登录] 用户 {target_user.user_id} 密码验证失败")
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
logger.info(f"[本地登录] Linux DO 授权用户 {target_user.user_id} 登录成功")
|
||||
user = target_user
|
||||
else:
|
||||
# 没有找到 Linux DO 用户,尝试 .env 配置的管理员账号
|
||||
logger.info(f"[本地登录] 未找到 Linux DO 用户,检查 .env 管理员账号")
|
||||
|
||||
# 检查是否配置了本地账户
|
||||
if not settings.LOCAL_AUTH_USERNAME or not settings.LOCAL_AUTH_PASSWORD:
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
# 生成本地用户ID(使用用户名的hash)
|
||||
user_id = f"local_{hashlib.md5(request.username.encode()).hexdigest()[:16]}"
|
||||
|
||||
# 检查用户是否存在
|
||||
user = await user_manager.get_user(user_id)
|
||||
|
||||
# 如果用户不存在,使用.env中的默认密码验证
|
||||
if not user:
|
||||
# 验证用户名和密码(使用.env配置)
|
||||
if request.username != settings.LOCAL_AUTH_USERNAME or request.password != settings.LOCAL_AUTH_PASSWORD:
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
# 创建本地用户
|
||||
user = await user_manager.create_or_update_from_linuxdo(
|
||||
linuxdo_id=user_id,
|
||||
username=request.username,
|
||||
display_name=settings.LOCAL_AUTH_DISPLAY_NAME,
|
||||
avatar_url=None,
|
||||
trust_level=9 # 本地用户给予高信任级别
|
||||
)
|
||||
|
||||
# 为新用户设置默认密码到数据库
|
||||
await password_manager.set_password(user.user_id, request.username, request.password)
|
||||
logger.info(f"[本地登录] 管理员用户 {user.user_id} 初始密码已设置到数据库")
|
||||
else:
|
||||
# 用户已存在,使用数据库中的密码验证
|
||||
if not await password_manager.verify_password(user.user_id, request.password):
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
logger.info(f"[本地登录] 管理员用户 {user.user_id} 登录成功")
|
||||
|
||||
# 初始化用户数据库
|
||||
try:
|
||||
@@ -189,6 +256,11 @@ async def _handle_callback(
|
||||
trust_level=trust_level
|
||||
)
|
||||
|
||||
# 3.1. 自动绑定密码(如果还没有设置)
|
||||
if not await password_manager.has_password(user.user_id):
|
||||
default_password = await password_manager.set_password(user.user_id, username)
|
||||
logger.info(f"用户 {user.user_id} ({username}) 自动绑定默认密码: {default_password}")
|
||||
|
||||
# 3.5. 初始化用户数据库(如果是新用户)
|
||||
try:
|
||||
await init_db(user.user_id)
|
||||
@@ -337,4 +409,126 @@ async def get_current_user(request: Request):
|
||||
if not hasattr(request.state, "user") or not request.state.user:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
return request.state.user.dict()
|
||||
return request.state.user.dict()
|
||||
|
||||
|
||||
@router.get("/password/status", response_model=PasswordStatusResponse)
|
||||
async def get_password_status(request: Request):
|
||||
"""获取当前用户的密码状态"""
|
||||
if not hasattr(request.state, "user") or not request.state.user:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
user = request.state.user
|
||||
has_password = await password_manager.has_password(user.user_id)
|
||||
has_custom = await password_manager.has_custom_password(user.user_id)
|
||||
username = await password_manager.get_username(user.user_id)
|
||||
|
||||
# 如果使用默认密码,返回默认密码供用户查看
|
||||
default_password = None
|
||||
if has_password and not has_custom:
|
||||
default_password = f"{user.username}@666"
|
||||
|
||||
return PasswordStatusResponse(
|
||||
has_password=has_password,
|
||||
has_custom_password=has_custom,
|
||||
username=username or user.username,
|
||||
default_password=default_password
|
||||
)
|
||||
|
||||
|
||||
@router.post("/password/set", response_model=SetPasswordResponse)
|
||||
async def set_user_password(request: Request, password_req: SetPasswordRequest):
|
||||
"""设置当前用户的密码"""
|
||||
if not hasattr(request.state, "user") or not request.state.user:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
user = request.state.user
|
||||
|
||||
# 验证密码强度(至少6个字符)
|
||||
if len(password_req.password) < 6:
|
||||
raise HTTPException(status_code=400, detail="密码长度至少为6个字符")
|
||||
|
||||
# 设置密码
|
||||
await password_manager.set_password(user.user_id, user.username, password_req.password)
|
||||
logger.info(f"用户 {user.user_id} ({user.username}) 设置了自定义密码")
|
||||
|
||||
return SetPasswordResponse(
|
||||
success=True,
|
||||
message="密码设置成功"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/bind/login", response_model=LocalLoginResponse)
|
||||
async def bind_account_login(request: LocalLoginRequest, response: Response):
|
||||
"""使用绑定的账号密码登录(LinuxDO授权后绑定的账号)"""
|
||||
# 查找用户
|
||||
all_users = await user_manager.get_all_users()
|
||||
target_user = None
|
||||
|
||||
logger.info(f"[绑定账号登录] 尝试登录用户名: {request.username}")
|
||||
logger.info(f"[绑定账号登录] 当前共有 {len(all_users)} 个用户")
|
||||
|
||||
for user in all_users:
|
||||
# 同时检查 users 表的 username 和 user_passwords 表的 username
|
||||
password_username = await password_manager.get_username(user.user_id)
|
||||
logger.info(f"[绑定账号登录] 检查用户 {user.user_id}: users.username={user.username}, passwords.username={password_username}")
|
||||
|
||||
if user.username == request.username or password_username == request.username:
|
||||
target_user = user
|
||||
logger.info(f"[绑定账号登录] 找到匹配用户: {user.user_id}")
|
||||
break
|
||||
|
||||
if not target_user:
|
||||
logger.warning(f"[绑定账号登录] 用户名 {request.username} 未找到")
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
# 检查是否有密码记录
|
||||
has_pwd = await password_manager.has_password(target_user.user_id)
|
||||
if not has_pwd:
|
||||
logger.warning(f"[绑定账号登录] 用户 {target_user.user_id} 没有设置密码")
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
# 验证密码
|
||||
is_valid = await password_manager.verify_password(target_user.user_id, request.password)
|
||||
logger.info(f"[绑定账号登录] 用户 {target_user.user_id} 密码验证结果: {is_valid}")
|
||||
|
||||
if not is_valid:
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
# 初始化用户数据库
|
||||
try:
|
||||
await init_db(target_user.user_id)
|
||||
logger.info(f"绑定账号用户 {target_user.user_id} 数据库初始化成功")
|
||||
except Exception as e:
|
||||
logger.error(f"绑定账号用户 {target_user.user_id} 数据库初始化失败: {e}")
|
||||
|
||||
# 设置 Cookie(2小时有效)
|
||||
max_age = settings.SESSION_EXPIRE_MINUTES * 60
|
||||
response.set_cookie(
|
||||
key="user_id",
|
||||
value=target_user.user_id,
|
||||
max_age=max_age,
|
||||
httponly=True,
|
||||
samesite="lax"
|
||||
)
|
||||
|
||||
# 设置过期时间戳 Cookie(用于前端判断)
|
||||
china_now = get_china_now()
|
||||
expire_time = china_now + timedelta(minutes=settings.SESSION_EXPIRE_MINUTES)
|
||||
expire_at = int(expire_time.timestamp())
|
||||
|
||||
logger.info(f"✅ [绑定账号登录] 用户 {target_user.user_id} ({request.username}) 登录成功,会话有效期 {settings.SESSION_EXPIRE_MINUTES} 分钟")
|
||||
|
||||
response.set_cookie(
|
||||
key="session_expire_at",
|
||||
value=str(expire_at),
|
||||
max_age=max_age,
|
||||
httponly=False, # 前端需要读取
|
||||
samesite="lax"
|
||||
)
|
||||
|
||||
return LocalLoginResponse(
|
||||
success=True,
|
||||
message="登录成功",
|
||||
user=target_user.dict()
|
||||
)
|
||||
+296
-10
@@ -1,6 +1,5 @@
|
||||
"""章节管理API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Query, BackgroundTasks
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
import json
|
||||
@@ -19,6 +18,7 @@ from app.models.writing_style import WritingStyle
|
||||
from app.models.analysis_task import AnalysisTask
|
||||
from app.models.memory import PlotAnalysis, StoryMemory
|
||||
from app.models.batch_generation_task import BatchGenerationTask
|
||||
from app.models.regeneration_task import RegenerationTask
|
||||
from app.schemas.chapter import (
|
||||
ChapterCreate,
|
||||
ChapterUpdate,
|
||||
@@ -29,12 +29,19 @@ from app.schemas.chapter import (
|
||||
BatchGenerateResponse,
|
||||
BatchGenerateStatusResponse
|
||||
)
|
||||
from app.schemas.regeneration import (
|
||||
ChapterRegenerateRequest,
|
||||
RegenerationTaskResponse,
|
||||
RegenerationTaskStatus
|
||||
)
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.prompt_service import prompt_service
|
||||
from app.services.plot_analyzer import PlotAnalyzer
|
||||
from app.services.memory_service import memory_service
|
||||
from app.services.chapter_regenerator import ChapterRegenerator
|
||||
from app.logger import get_logger
|
||||
from app.api.settings import get_user_ai_service
|
||||
from app.utils.sse_response import create_sse_response
|
||||
|
||||
router = APIRouter(prefix="/chapters", tags=["章节管理"])
|
||||
logger = get_logger(__name__)
|
||||
@@ -1284,15 +1291,7 @@ async def generate_chapter_content_stream(
|
||||
except:
|
||||
pass
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
return create_sse_response(event_generator())
|
||||
|
||||
|
||||
@router.get("/{chapter_id}/analysis/status", summary="查询章节分析任务状态")
|
||||
@@ -2293,3 +2292,290 @@ async def generate_single_chapter_for_batch(
|
||||
await db_session.refresh(chapter)
|
||||
|
||||
logger.info(f"✅ 单章节生成完成: 第{chapter.chapter_number}章,共 {new_word_count} 字")
|
||||
|
||||
|
||||
|
||||
|
||||
# ==================== 章节重新生成相关API ====================
|
||||
|
||||
@router.post("/{chapter_id}/regenerate-stream", summary="流式重新生成章节内容")
|
||||
async def regenerate_chapter_stream(
|
||||
chapter_id: str,
|
||||
request: Request,
|
||||
regenerate_request: ChapterRegenerateRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
"""
|
||||
根据分析建议或自定义指令重新生成章节内容(流式返回)
|
||||
|
||||
工作流程:
|
||||
1. 验证章节和分析结果
|
||||
2. 创建重新生成任务
|
||||
3. 构建修改指令
|
||||
4. 流式生成新内容
|
||||
5. 保存为版本历史
|
||||
6. 可选自动应用
|
||||
"""
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
# 验证章节存在
|
||||
chapter_result = await db.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
chapter = chapter_result.scalar_one_or_none()
|
||||
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
if not chapter.content or chapter.content.strip() == "":
|
||||
raise HTTPException(status_code=400, detail="章节内容为空,无法重新生成")
|
||||
|
||||
# 验证用户权限
|
||||
await verify_project_access(chapter.project_id, user_id, db)
|
||||
|
||||
# 获取分析结果(如果使用分析建议)
|
||||
analysis = None
|
||||
if regenerate_request.modification_source in ['analysis_suggestions', 'mixed']:
|
||||
analysis_result = await db.execute(
|
||||
select(PlotAnalysis)
|
||||
.where(PlotAnalysis.chapter_id == chapter_id)
|
||||
.order_by(PlotAnalysis.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
analysis = analysis_result.scalar_one_or_none()
|
||||
|
||||
if not analysis:
|
||||
raise HTTPException(status_code=404, detail="该章节暂无分析结果")
|
||||
|
||||
# 预先获取项目上下文数据
|
||||
async for temp_db in get_db(request):
|
||||
try:
|
||||
# 获取项目信息
|
||||
project_result = await temp_db.execute(
|
||||
select(Project).where(Project.id == chapter.project_id)
|
||||
)
|
||||
project = project_result.scalar_one_or_none()
|
||||
|
||||
# 获取角色信息
|
||||
characters_result = await temp_db.execute(
|
||||
select(Character).where(Character.project_id == chapter.project_id)
|
||||
)
|
||||
characters = characters_result.scalars().all()
|
||||
|
||||
# 获取章节大纲
|
||||
outline_result = await temp_db.execute(
|
||||
select(Outline)
|
||||
.where(Outline.project_id == chapter.project_id)
|
||||
.where(Outline.order_index == chapter.chapter_number)
|
||||
)
|
||||
outline = outline_result.scalar_one_or_none()
|
||||
|
||||
# 构建项目上下文
|
||||
project_context = {
|
||||
'project_title': project.title if project else '未知',
|
||||
'genre': project.genre if project else '未设定',
|
||||
'theme': project.theme if project else '未设定',
|
||||
'narrative_perspective': project.narrative_perspective if project else '第三人称',
|
||||
'time_period': project.world_time_period if project else '未设定',
|
||||
'location': project.world_location if project else '未设定',
|
||||
'atmosphere': project.world_atmosphere if project else '未设定',
|
||||
'characters_info': "\n".join([
|
||||
f"- {c.name}({'组织' if c.is_organization else '角色'}, {c.role_type}): {c.personality[:100] if c.personality else ''}"
|
||||
for c in characters
|
||||
]) if characters else '暂无角色信息',
|
||||
'chapter_outline': outline.content if outline else chapter.summary or '暂无大纲',
|
||||
'previous_context': '' # 可以后续扩展添加前置章节上下文
|
||||
}
|
||||
finally:
|
||||
await temp_db.close()
|
||||
break
|
||||
|
||||
async def event_generator():
|
||||
"""流式生成事件生成器"""
|
||||
db_session = None
|
||||
db_committed = False
|
||||
|
||||
try:
|
||||
# 创建独立数据库会话
|
||||
async for db_session in get_db(request):
|
||||
# 发送开始事件
|
||||
yield f"data: {json.dumps({'type': 'start', 'message': '开始重新生成章节...'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 创建重新生成任务
|
||||
regen_task = RegenerationTask(
|
||||
chapter_id=chapter_id,
|
||||
analysis_id=analysis.id if analysis else None,
|
||||
user_id=user_id,
|
||||
project_id=chapter.project_id,
|
||||
modification_instructions="", # 稍后填充
|
||||
original_suggestions=analysis.suggestions if analysis else None,
|
||||
selected_suggestion_indices=regenerate_request.selected_suggestion_indices,
|
||||
custom_instructions=regenerate_request.custom_instructions,
|
||||
style_id=regenerate_request.style_id,
|
||||
target_word_count=regenerate_request.target_word_count,
|
||||
focus_areas=regenerate_request.focus_areas,
|
||||
preserve_elements=regenerate_request.preserve_elements.model_dump() if regenerate_request.preserve_elements else None,
|
||||
status='running',
|
||||
original_content=chapter.content,
|
||||
original_word_count=chapter.word_count or len(chapter.content),
|
||||
version_note=regenerate_request.version_note,
|
||||
started_at=datetime.now()
|
||||
)
|
||||
db_session.add(regen_task)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(regen_task)
|
||||
|
||||
task_id = regen_task.id
|
||||
logger.info(f"📝 创建重新生成任务: {task_id}")
|
||||
|
||||
yield f"data: {json.dumps({'type': 'task_created', 'task_id': task_id}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 初始化重新生成器
|
||||
regenerator = ChapterRegenerator(user_ai_service)
|
||||
|
||||
# 流式生成新内容
|
||||
full_content = ""
|
||||
async for event in regenerator.regenerate_with_feedback(
|
||||
chapter=chapter,
|
||||
analysis=analysis,
|
||||
regenerate_request=regenerate_request,
|
||||
project_context=project_context
|
||||
):
|
||||
# 处理不同类型的事件
|
||||
if event['type'] == 'chunk':
|
||||
# 内容块
|
||||
chunk = event['content']
|
||||
full_content += chunk
|
||||
yield f"data: {json.dumps({'type': 'chunk', 'content': chunk}, ensure_ascii=False)}\n\n"
|
||||
elif event['type'] == 'progress':
|
||||
# 进度更新
|
||||
progress_data = {
|
||||
'type': 'progress',
|
||||
'progress': event.get('progress', 0),
|
||||
'message': event.get('message', ''),
|
||||
'word_count': event.get('word_count', 0)
|
||||
}
|
||||
yield f"data: {json.dumps(progress_data, ensure_ascii=False)}\n\n"
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# 更新任务状态
|
||||
regen_task.status = 'completed'
|
||||
regen_task.regenerated_content = full_content
|
||||
regen_task.regenerated_word_count = len(full_content)
|
||||
regen_task.completed_at = datetime.now()
|
||||
|
||||
# 计算差异统计
|
||||
diff_stats = regenerator.calculate_content_diff(chapter.content, full_content)
|
||||
|
||||
await db_session.commit()
|
||||
db_committed = True
|
||||
|
||||
# 先发送结果数据
|
||||
result_data = {
|
||||
'type': 'result',
|
||||
'data': {
|
||||
'task_id': task_id,
|
||||
'word_count': len(full_content),
|
||||
'version_number': regen_task.version_number,
|
||||
'auto_applied': regenerate_request.auto_apply,
|
||||
'diff_stats': diff_stats
|
||||
}
|
||||
}
|
||||
yield f"data: {json.dumps(result_data, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 再发送完成事件
|
||||
completion_data = {
|
||||
'type': 'done',
|
||||
'message': '重新生成完成'
|
||||
}
|
||||
yield f"data: {json.dumps(completion_data, ensure_ascii=False)}\n\n"
|
||||
|
||||
logger.info(f"✅ 章节重新生成完成: {chapter_id}, 任务: {task_id}")
|
||||
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 重新生成失败: {str(e)}", exc_info=True)
|
||||
|
||||
# 更新任务状态为失败
|
||||
if db_session and not db_committed:
|
||||
try:
|
||||
task_result = await db_session.execute(
|
||||
select(RegenerationTask).where(RegenerationTask.chapter_id == chapter_id)
|
||||
.order_by(RegenerationTask.created_at.desc()).limit(1)
|
||||
)
|
||||
task = task_result.scalar_one_or_none()
|
||||
if task:
|
||||
task.status = 'failed'
|
||||
task.error_message = str(e)[:500]
|
||||
task.completed_at = datetime.now()
|
||||
await db_session.commit()
|
||||
except Exception as update_error:
|
||||
logger.error(f"更新任务失败状态失败: {str(update_error)}")
|
||||
|
||||
yield f"data: {json.dumps({'type': 'error', 'error': str(e)}, ensure_ascii=False)}\n\n"
|
||||
|
||||
finally:
|
||||
if db_session:
|
||||
try:
|
||||
if not db_committed and db_session.in_transaction():
|
||||
await db_session.rollback()
|
||||
await db_session.close()
|
||||
except Exception as close_error:
|
||||
logger.error(f"关闭数据库会话失败: {str(close_error)}")
|
||||
|
||||
return create_sse_response(event_generator())
|
||||
|
||||
|
||||
@router.get("/{chapter_id}/regeneration/tasks", summary="获取章节的重新生成任务列表")
|
||||
async def get_regeneration_tasks(
|
||||
chapter_id: str,
|
||||
request: Request,
|
||||
limit: int = Query(10, ge=1, le=50),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取指定章节的重新生成任务历史"""
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
|
||||
# 验证章节存在和权限
|
||||
chapter_result = await db.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
chapter = chapter_result.scalar_one_or_none()
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
await verify_project_access(chapter.project_id, user_id, db)
|
||||
|
||||
# 获取任务列表
|
||||
result = await db.execute(
|
||||
select(RegenerationTask)
|
||||
.where(RegenerationTask.chapter_id == chapter_id)
|
||||
.order_by(RegenerationTask.created_at.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
tasks = result.scalars().all()
|
||||
|
||||
return {
|
||||
"chapter_id": chapter_id,
|
||||
"total": len(tasks),
|
||||
"tasks": [
|
||||
{
|
||||
"task_id": task.id,
|
||||
"status": task.status,
|
||||
"version_number": task.version_number,
|
||||
"version_note": task.version_note,
|
||||
"original_word_count": task.original_word_count,
|
||||
"regenerated_word_count": task.regenerated_word_count,
|
||||
"created_at": task.created_at.isoformat() if task.created_at else None,
|
||||
"completed_at": task.completed_at.isoformat() if task.completed_at else None
|
||||
}
|
||||
for task in tasks
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@@ -1135,386 +1135,3 @@ async def generate_outline_stream(
|
||||
"""
|
||||
return create_sse_response(outline_generator(data, db, user_ai_service))
|
||||
|
||||
|
||||
async def update_world_building_generator(
|
||||
project_id: str,
|
||||
data: Dict[str, Any],
|
||||
db: AsyncSession
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""更新世界观流式生成器"""
|
||||
db_committed = False
|
||||
try:
|
||||
yield await SSEResponse.send_progress("开始更新世界观...", 10)
|
||||
|
||||
# 获取项目
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
yield await SSEResponse.send_error("项目不存在", 404)
|
||||
return
|
||||
|
||||
yield await SSEResponse.send_progress("验证数据...", 30)
|
||||
|
||||
# 更新世界观字段
|
||||
if "time_period" in data:
|
||||
project.world_time_period = data["time_period"]
|
||||
if "location" in data:
|
||||
project.world_location = data["location"]
|
||||
if "atmosphere" in data:
|
||||
project.world_atmosphere = data["atmosphere"]
|
||||
if "rules" in data:
|
||||
project.world_rules = data["rules"]
|
||||
|
||||
yield await SSEResponse.send_progress("保存到数据库...", 70)
|
||||
|
||||
await db.commit()
|
||||
db_committed = True
|
||||
await db.refresh(project)
|
||||
|
||||
# 发送结果
|
||||
yield await SSEResponse.send_result({
|
||||
"project_id": project.id,
|
||||
"time_period": project.world_time_period,
|
||||
"location": project.world_location,
|
||||
"atmosphere": project.world_atmosphere,
|
||||
"rules": project.world_rules
|
||||
})
|
||||
|
||||
yield await SSEResponse.send_progress("完成!", 100, "success")
|
||||
yield await SSEResponse.send_done()
|
||||
|
||||
except GeneratorExit:
|
||||
logger.warning("更新世界观生成器被提前关闭")
|
||||
if not db_committed and db.in_transaction():
|
||||
await db.rollback()
|
||||
logger.info("更新世界观事务已回滚(GeneratorExit)")
|
||||
except Exception as e:
|
||||
logger.error(f"更新世界观失败: {str(e)}")
|
||||
if not db_committed and db.in_transaction():
|
||||
await db.rollback()
|
||||
logger.info("更新世界观事务已回滚(异常)")
|
||||
yield await SSEResponse.send_error(f"更新失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/world-building/{project_id}", summary="流式更新世界观")
|
||||
async def update_world_building_stream(
|
||||
project_id: str,
|
||||
data: Dict[str, Any],
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
使用SSE流式更新项目的世界观信息
|
||||
请求体格式:
|
||||
{
|
||||
"time_period": "时间背景",
|
||||
"location": "地理位置",
|
||||
"atmosphere": "氛围基调",
|
||||
"rules": "世界规则"
|
||||
}
|
||||
"""
|
||||
return create_sse_response(update_world_building_generator(project_id, data, db))
|
||||
|
||||
|
||||
async def regenerate_world_building_generator(
|
||||
project_id: str,
|
||||
data: Dict[str, Any],
|
||||
db: AsyncSession,
|
||||
user_ai_service: AIService
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""重新生成世界观流式生成器 - 支持MCP工具增强"""
|
||||
db_committed = False
|
||||
try:
|
||||
yield await SSEResponse.send_progress("开始重新生成世界观...", 10)
|
||||
|
||||
# 获取项目
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
yield await SSEResponse.send_error("项目不存在", 404)
|
||||
return
|
||||
|
||||
provider = data.get("provider")
|
||||
model = data.get("model")
|
||||
enable_mcp = data.get("enable_mcp", True) # 默认启用MCP
|
||||
user_id = data.get("user_id") # 从中间件注入
|
||||
|
||||
# 获取基础提示词
|
||||
yield await SSEResponse.send_progress("准备AI提示词...", 15)
|
||||
base_prompt = prompt_service.get_world_building_prompt(
|
||||
title=project.title,
|
||||
theme=project.theme or "",
|
||||
genre=project.genre or ""
|
||||
)
|
||||
|
||||
# MCP工具增强:收集参考资料
|
||||
reference_materials = ""
|
||||
if enable_mcp and user_id:
|
||||
try:
|
||||
yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集参考资料...", 18)
|
||||
|
||||
# 直接调用MCP增强的AI,内部会自动检查和加载工具
|
||||
# 构建资料收集提示词
|
||||
planning_prompt = f"""你正在为小说《{project.title}》重新设计世界观。
|
||||
|
||||
【小说信息】
|
||||
- 题材:{project.genre or '未设定'}
|
||||
- 主题:{project.theme or '未设定'}
|
||||
|
||||
【任务】
|
||||
请使用可用工具搜索相关背景资料,帮助构建更真实、更有深度的世界观设定。
|
||||
你可以查询:
|
||||
1. 历史背景(如果是历史题材)
|
||||
2. 地理环境和文化特征
|
||||
3. 相关领域的专业知识
|
||||
4. 类似作品的设定参考
|
||||
|
||||
请根据题材特点,有针对性地查询2-3个关键问题。"""
|
||||
|
||||
# 调用MCP增强的AI(非流式,最多2轮工具调用)
|
||||
planning_result = await user_ai_service.generate_text_with_mcp(
|
||||
prompt=planning_prompt,
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=2,
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
)
|
||||
|
||||
# 提取参考资料
|
||||
if planning_result.get("tool_calls_made", 0) > 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"✅ MCP工具调用成功({planning_result['tool_calls_made']}次)",
|
||||
25
|
||||
)
|
||||
reference_materials = planning_result.get("content", "")
|
||||
else:
|
||||
yield await SSEResponse.send_progress("ℹ️ 未使用MCP工具(无可用工具或不需要)", 25)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"MCP工具调用失败(降级处理): {e}")
|
||||
yield await SSEResponse.send_progress("⚠️ MCP工具暂时不可用,使用基础模式", 25)
|
||||
|
||||
# 构建增强提示词
|
||||
if reference_materials:
|
||||
enhanced_prompt = f"""{base_prompt}
|
||||
|
||||
【参考资料】
|
||||
以下是通过MCP工具收集的真实背景资料,请参考这些信息构建更真实的世界观:
|
||||
|
||||
{reference_materials}
|
||||
|
||||
请结合上述资料,生成符合历史/现实的世界观设定。"""
|
||||
final_prompt = enhanced_prompt
|
||||
yield await SSEResponse.send_progress("💡 已整合参考资料,开始重新生成世界观...", 30)
|
||||
else:
|
||||
final_prompt = base_prompt
|
||||
yield await SSEResponse.send_progress("正在调用AI生成...", 30)
|
||||
|
||||
# 流式生成世界观
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=final_prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
):
|
||||
chunk_count += 1
|
||||
accumulated_text += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度
|
||||
if chunk_count % 5 == 0:
|
||||
progress = min(30 + (chunk_count // 5), 70)
|
||||
yield await SSEResponse.send_progress(f"生成中... ({len(accumulated_text)}字符)", progress)
|
||||
|
||||
# 每20个块发送心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
# 解析结果
|
||||
yield await SSEResponse.send_progress("解析AI返回结果...", 80)
|
||||
|
||||
world_data = {}
|
||||
try:
|
||||
cleaned_text = accumulated_text.strip()
|
||||
# 移除markdown代码块标记
|
||||
if cleaned_text.startswith('```json'):
|
||||
cleaned_text = cleaned_text[7:].lstrip('\n\r')
|
||||
elif cleaned_text.startswith('```'):
|
||||
cleaned_text = cleaned_text[3:].lstrip('\n\r')
|
||||
if cleaned_text.endswith('```'):
|
||||
cleaned_text = cleaned_text[:-3].rstrip('\n\r')
|
||||
cleaned_text = cleaned_text.strip()
|
||||
|
||||
world_data = json.loads(cleaned_text)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"AI返回非JSON格式: {e}")
|
||||
logger.info(world_data)
|
||||
world_data = {
|
||||
"time_period": "AI返回格式错误,请重试",
|
||||
"location": "AI返回格式错误,请重试",
|
||||
"atmosphere": "AI返回格式错误,请重试",
|
||||
"rules": "AI返回格式错误,请重试"
|
||||
}
|
||||
|
||||
# 更新项目世界观
|
||||
yield await SSEResponse.send_progress("保存到数据库...", 90)
|
||||
|
||||
project.world_time_period = world_data.get("time_period")
|
||||
project.world_location = world_data.get("location")
|
||||
project.world_atmosphere = world_data.get("atmosphere")
|
||||
project.world_rules = world_data.get("rules")
|
||||
|
||||
await db.commit()
|
||||
db_committed = True
|
||||
await db.refresh(project)
|
||||
|
||||
# 发送结果
|
||||
yield await SSEResponse.send_result({
|
||||
"project_id": project.id,
|
||||
"time_period": project.world_time_period,
|
||||
"location": project.world_location,
|
||||
"atmosphere": project.world_atmosphere,
|
||||
"rules": project.world_rules
|
||||
})
|
||||
|
||||
yield await SSEResponse.send_progress("完成!", 100, "success")
|
||||
yield await SSEResponse.send_done()
|
||||
|
||||
except GeneratorExit:
|
||||
logger.warning("重新生成世界观生成器被提前关闭")
|
||||
if not db_committed and db.in_transaction():
|
||||
await db.rollback()
|
||||
logger.info("重新生成世界观事务已回滚(GeneratorExit)")
|
||||
except Exception as e:
|
||||
logger.error(f"重新生成世界观失败: {str(e)}")
|
||||
if not db_committed and db.in_transaction():
|
||||
await db.rollback()
|
||||
logger.info("重新生成世界观事务已回滚(异常)")
|
||||
yield await SSEResponse.send_error(f"重新生成失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/world-building/{project_id}/regenerate", summary="流式重新生成世界观")
|
||||
async def regenerate_world_building_stream(
|
||||
request: Request,
|
||||
project_id: str,
|
||||
data: Dict[str, Any],
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
"""
|
||||
使用SSE流式重新生成项目的世界观
|
||||
请求体格式:
|
||||
{
|
||||
"provider": "AI提供商(可选)",
|
||||
"model": "模型名称(可选)"
|
||||
}
|
||||
"""
|
||||
# 从中间件注入user_id到data中
|
||||
if hasattr(request.state, 'user_id'):
|
||||
data['user_id'] = request.state.user_id
|
||||
|
||||
return create_sse_response(regenerate_world_building_generator(project_id, data, db, user_ai_service))
|
||||
|
||||
|
||||
async def cleanup_wizard_data_generator(
|
||||
project_id: str,
|
||||
db: AsyncSession
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""清理向导数据流式生成器"""
|
||||
db_committed = False
|
||||
try:
|
||||
yield await SSEResponse.send_progress("开始清理向导数据...", 10)
|
||||
|
||||
# 获取项目
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
yield await SSEResponse.send_error("项目不存在", 404)
|
||||
return
|
||||
|
||||
# 删除相关的角色
|
||||
yield await SSEResponse.send_progress("删除角色数据...", 30)
|
||||
characters = await db.execute(
|
||||
select(Character).where(Character.project_id == project_id)
|
||||
)
|
||||
char_count = 0
|
||||
for character in characters.scalars():
|
||||
await db.delete(character)
|
||||
char_count += 1
|
||||
|
||||
# 删除相关的大纲
|
||||
yield await SSEResponse.send_progress("删除大纲数据...", 50)
|
||||
outlines = await db.execute(
|
||||
select(Outline).where(Outline.project_id == project_id)
|
||||
)
|
||||
outline_count = 0
|
||||
for outline in outlines.scalars():
|
||||
await db.delete(outline)
|
||||
outline_count += 1
|
||||
|
||||
# 删除相关的章节
|
||||
yield await SSEResponse.send_progress("删除章节数据...", 70)
|
||||
chapters = await db.execute(
|
||||
select(Chapter).where(Chapter.project_id == project_id)
|
||||
)
|
||||
chapter_count = 0
|
||||
for chapter in chapters.scalars():
|
||||
await db.delete(chapter)
|
||||
chapter_count += 1
|
||||
|
||||
# 删除项目
|
||||
yield await SSEResponse.send_progress("删除项目...", 85)
|
||||
await db.delete(project)
|
||||
|
||||
yield await SSEResponse.send_progress("提交数据库更改...", 95)
|
||||
await db.commit()
|
||||
db_committed = True
|
||||
|
||||
# 发送结果
|
||||
yield await SSEResponse.send_result({
|
||||
"message": "项目及相关数据已清理",
|
||||
"deleted": {
|
||||
"characters": char_count,
|
||||
"outlines": outline_count,
|
||||
"chapters": chapter_count
|
||||
}
|
||||
})
|
||||
|
||||
yield await SSEResponse.send_progress("完成!", 100, "success")
|
||||
yield await SSEResponse.send_done()
|
||||
|
||||
except GeneratorExit:
|
||||
logger.warning("清理向导数据生成器被提前关闭")
|
||||
if not db_committed and db.in_transaction():
|
||||
await db.rollback()
|
||||
logger.info("清理向导数据事务已回滚(GeneratorExit)")
|
||||
except Exception as e:
|
||||
logger.error(f"清理数据失败: {str(e)}")
|
||||
if not db_committed and db.in_transaction():
|
||||
await db.rollback()
|
||||
logger.info("清理向导数据事务已回滚(异常)")
|
||||
yield await SSEResponse.send_error(f"清理失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/cleanup/{project_id}", summary="流式清理向导数据")
|
||||
async def cleanup_wizard_data_stream(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
使用SSE流式清理向导过程中创建的项目及相关数据
|
||||
用于返回上一步时清理已生成的内容
|
||||
"""
|
||||
return create_sse_response(cleanup_wizard_data_generator(project_id, db))
|
||||
Reference in New Issue
Block a user