feat: 后台任务系统 + JSON容错解析 + SSE心跳保活 + 多项Bug修复
新功能: - 大纲/章节生成改为服务端后台任务,支持断线续传 - 后台任务队列排队执行,按用户排队(同用户串行不同用户并发) - 章节管理页面添加后台任务列表弹窗和进度面板 - 章节状态添加 pending(待处理)选项 - 集成json5容错解析器 + 上下文感知JSON修复 - SSE流式生成添加心跳保活,防止连接超时 - SSEPostClient添加credentials:include修复network error - 每章最大伏笔数从2调整为5 - 添加大纲读区伏笔的功能 Bug修复: - 修复AI生成JSON中未转义引号/中文标点/多对象属性值未合并 - 修复JSON非法转义字符清洗和中文引号处理 - 修复MCP插件TimeoutError/连接失败上下文清理 - MCP插件后台注册添加重试机制 - 续写模式添加缺失的mcp_references参数 - 修复Alembic迁移链分叉 - 使用torch CPU版本加速Docker构建
This commit is contained in:
+2
-1
@@ -124,4 +124,5 @@ test_api.py
|
||||
backend/embedding/
|
||||
|
||||
# 提示词工坊实例标识(每个部署实例必须唯一)
|
||||
backend/.instance_id
|
||||
backend/.instance_id
|
||||
test.json
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
"""添加后台任务表
|
||||
|
||||
Revision ID: abc12345
|
||||
Revises:
|
||||
Create Date: 2026-04-27 12:00:00.000000
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import JSON, JSONB
|
||||
|
||||
# revision identifiers
|
||||
revision = 'abc12345'
|
||||
down_revision = '9a1b2c3d4e5f'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
'background_tasks',
|
||||
sa.Column('id', sa.String(36), primary_key=True),
|
||||
sa.Column('user_id', sa.String(100), nullable=False, index=True, comment='用户ID'),
|
||||
sa.Column('project_id', sa.String(36), nullable=False, index=True, comment='项目ID'),
|
||||
sa.Column('task_type', sa.String(50), nullable=False, comment='任务类型'),
|
||||
sa.Column('status', sa.String(20), default='pending', comment='任务状态'),
|
||||
sa.Column('progress', sa.Integer, default=0, comment='进度百分比'),
|
||||
sa.Column('status_message', sa.String(500), comment='当前状态消息'),
|
||||
sa.Column('task_input', JSON, comment='任务输入参数'),
|
||||
sa.Column('task_result', JSON, comment='任务结果'),
|
||||
sa.Column('error_message', sa.Text, comment='错误信息'),
|
||||
sa.Column('progress_details', JSON, comment='进度详情'),
|
||||
sa.Column('cancel_requested', sa.Boolean, default=False, comment='是否请求取消'),
|
||||
sa.Column('retry_count', sa.Integer, default=0, comment='已重试次数'),
|
||||
sa.Column('max_retries', sa.Integer, default=3, comment='最大重试次数'),
|
||||
sa.Column('created_at', sa.DateTime, server_default=sa.func.now(), comment='创建时间'),
|
||||
sa.Column('started_at', sa.DateTime, comment='开始时间'),
|
||||
sa.Column('completed_at', sa.DateTime, comment='完成时间'),
|
||||
sa.Column('updated_at', sa.DateTime, server_default=sa.func.now(), onupdate=sa.func.now(), comment='更新时间'),
|
||||
)
|
||||
# 添加复合索引:按用户+项目+状态查询
|
||||
op.create_index('ix_background_tasks_user_project', 'background_tasks', ['user_id', 'project_id', 'status'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index('ix_background_tasks_user_project', table_name='background_tasks')
|
||||
op.drop_table('background_tasks')
|
||||
@@ -0,0 +1,44 @@
|
||||
"""添加后台任务表
|
||||
|
||||
Revision ID: def45678ghi9
|
||||
Revises: ab12cd34ef56
|
||||
Create Date: 2026-04-27 12:00:00.000000
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'def45678ghi9'
|
||||
down_revision = 'ab12cd34ef56'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
'background_tasks',
|
||||
sa.Column('id', sa.String(), primary_key=True),
|
||||
sa.Column('task_type', sa.String(), nullable=False, index=True),
|
||||
sa.Column('project_id', sa.String(), nullable=False, index=True),
|
||||
sa.Column('user_id', sa.String(), nullable=True),
|
||||
sa.Column('status', sa.String(), default='pending', nullable=False, index=True),
|
||||
sa.Column('progress', sa.Float(), default=0.0),
|
||||
sa.Column('status_message', sa.String(), nullable=True),
|
||||
sa.Column('progress_details', sa.Text(), nullable=True),
|
||||
sa.Column('error_message', sa.Text(), nullable=True),
|
||||
sa.Column('task_params', sa.Text(), nullable=True),
|
||||
sa.Column('task_result', sa.Text(), nullable=True),
|
||||
sa.Column('cancel_requested', sa.Boolean(), default=False),
|
||||
sa.Column('retry_count', sa.Integer(), default=0),
|
||||
sa.Column('max_retries', sa.Integer(), default=3),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.func.now()),
|
||||
sa.Column('started_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('completed_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.func.now(), onupdate=sa.func.now()),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table('background_tasks')
|
||||
@@ -27,6 +27,7 @@ from app.models.analysis_task import AnalysisTask
|
||||
from app.models.memory import PlotAnalysis, StoryMemory
|
||||
from app.models.batch_generation_task import BatchGenerationTask
|
||||
from app.models.regeneration_task import RegenerationTask
|
||||
from app.models.background_task import BackgroundTask
|
||||
from app.schemas.chapter import (
|
||||
ChapterCreate,
|
||||
ChapterUpdate,
|
||||
@@ -1815,6 +1816,450 @@ async def generate_chapter_content_stream(
|
||||
return create_sse_response(event_generator())
|
||||
|
||||
|
||||
@router.post("/{chapter_id}/generate-background", summary="AI创作章节内容(后台任务)")
|
||||
async def generate_chapter_content_background(
|
||||
chapter_id: str,
|
||||
request: Request,
|
||||
generate_request: ChapterGenerateRequest = ChapterGenerateRequest(),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
创建后台任务来生成章节内容。
|
||||
任务创建后立即返回task_id,前端通过 GET /api/tasks/{task_id} 轮询进度。
|
||||
关闭浏览器不影响生成,生成完成后内容自动保存到数据库。
|
||||
"""
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
# 验证章节存在
|
||||
result = await db.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
chapter = result.scalar_one_or_none()
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
# 验证项目权限
|
||||
project = await verify_project_access(chapter.project_id, user_id, db)
|
||||
|
||||
# 检查前置条件
|
||||
can_generate, error_msg, _ = await check_prerequisites(db, chapter)
|
||||
if not can_generate:
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
|
||||
# 创建后台任务
|
||||
from app.services.background_task_service import background_task_service, TaskProgressTracker
|
||||
task = await background_task_service.create_task(
|
||||
user_id=user_id,
|
||||
project_id=chapter.project_id,
|
||||
task_type="chapter_generate",
|
||||
task_input={
|
||||
"chapter_id": chapter_id,
|
||||
"style_id": generate_request.style_id,
|
||||
"target_word_count": generate_request.target_word_count or 3000,
|
||||
"enable_mcp": generate_request.enable_mcp,
|
||||
"model": generate_request.model,
|
||||
"narrative_perspective": generate_request.narrative_perspective,
|
||||
},
|
||||
db=db
|
||||
)
|
||||
|
||||
# 后台执行的函数
|
||||
async def _run_chapter_generation(task_id: str, bg_user_id: str):
|
||||
from app.database import get_engine
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession as BgAsyncSession
|
||||
|
||||
engine = await get_engine(bg_user_id)
|
||||
AsyncSessionLocal = async_sessionmaker(engine, class_=BgAsyncSession, expire_on_commit=False)
|
||||
|
||||
async with AsyncSessionLocal() as bg_db:
|
||||
tracker = TaskProgressTracker(task_id, bg_user_id, "章节")
|
||||
try:
|
||||
await tracker.start()
|
||||
|
||||
# 获取AI服务
|
||||
from app.api.settings import get_user_ai_service_from_db
|
||||
bg_ai_service = await get_user_ai_service_from_db(bg_user_id, bg_db)
|
||||
|
||||
await _run_chapter_generation_bg(
|
||||
task_input={
|
||||
"chapter_id": chapter_id,
|
||||
"style_id": generate_request.style_id,
|
||||
"target_word_count": generate_request.target_word_count or 3000,
|
||||
"enable_mcp": generate_request.enable_mcp,
|
||||
"model": generate_request.model,
|
||||
"narrative_perspective": generate_request.narrative_perspective,
|
||||
},
|
||||
db=bg_db,
|
||||
ai_service=bg_ai_service,
|
||||
tracker=tracker,
|
||||
user_id=bg_user_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 后台章节生成失败: {e}", exc_info=True)
|
||||
await tracker.error(str(e))
|
||||
|
||||
await background_task_service.spawn_background_task(
|
||||
task.id, user_id, _run_chapter_generation
|
||||
)
|
||||
|
||||
return {
|
||||
"task_id": task.id,
|
||||
"task_type": "chapter_generate",
|
||||
"status": "pending",
|
||||
"message": "任务已创建,请通过 GET /api/tasks/{task_id} 查询进度"
|
||||
}
|
||||
|
||||
|
||||
async def _run_chapter_generation_bg(
|
||||
task_input: dict,
|
||||
db: AsyncSession,
|
||||
ai_service: AIService,
|
||||
tracker,
|
||||
user_id: str,
|
||||
task_id: str,
|
||||
):
|
||||
"""后台执行章节生成(不使用SSE,直接生成并保存)"""
|
||||
from app.services.chapter_context_service import (
|
||||
OneToManyContextBuilder,
|
||||
OneToOneContextBuilder
|
||||
)
|
||||
|
||||
chapter_id = task_input["chapter_id"]
|
||||
style_id = task_input.get("style_id")
|
||||
target_word_count = task_input.get("target_word_count", 3000)
|
||||
custom_model = task_input.get("model")
|
||||
temp_narrative_perspective = task_input.get("narrative_perspective")
|
||||
write_lock = await get_db_write_lock(user_id)
|
||||
|
||||
# === 加载阶段 ===
|
||||
await tracker.loading("加载章节信息...", 0.2)
|
||||
|
||||
chapter_result = await db.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
current_chapter = chapter_result.scalar_one_or_none()
|
||||
if not current_chapter:
|
||||
await tracker.error("章节不存在")
|
||||
return
|
||||
|
||||
await tracker.loading("加载项目信息...", 0.4)
|
||||
|
||||
project_result = await db.execute(
|
||||
select(Project).where(Project.id == current_chapter.project_id)
|
||||
)
|
||||
project = project_result.scalar_one_or_none()
|
||||
if not project:
|
||||
await tracker.error("项目不存在")
|
||||
return
|
||||
|
||||
outline_mode = project.outline_mode if project else 'one-to-many'
|
||||
|
||||
# 获取大纲
|
||||
if current_chapter.outline_id:
|
||||
outline_result = await db.execute(
|
||||
select(Outline).where(Outline.id == current_chapter.outline_id)
|
||||
)
|
||||
else:
|
||||
outline_result = await db.execute(
|
||||
select(Outline)
|
||||
.where(Outline.project_id == current_chapter.project_id)
|
||||
.where(Outline.order_index == current_chapter.chapter_number)
|
||||
)
|
||||
outline = outline_result.scalar_one_or_none()
|
||||
|
||||
# 获取写作风格
|
||||
style_content = ""
|
||||
if style_id:
|
||||
style_result = await db.execute(
|
||||
select(WritingStyle).where(WritingStyle.id == style_id)
|
||||
)
|
||||
style = style_result.scalar_one_or_none()
|
||||
if style and (style.user_id is None or style.user_id == user_id):
|
||||
style_content = style.prompt_content or ""
|
||||
|
||||
# === 构建上下文 ===
|
||||
if outline_mode == 'one-to-one':
|
||||
context_builder = OneToOneContextBuilder(
|
||||
memory_service=memory_service,
|
||||
foreshadow_service=foreshadow_service
|
||||
)
|
||||
chapter_context = await context_builder.build(
|
||||
chapter=current_chapter,
|
||||
project=project,
|
||||
outline=outline,
|
||||
user_id=user_id,
|
||||
db=db,
|
||||
target_word_count=target_word_count
|
||||
)
|
||||
else:
|
||||
context_builder = OneToManyContextBuilder(
|
||||
memory_service=memory_service,
|
||||
foreshadow_service=foreshadow_service
|
||||
)
|
||||
chapter_context = await context_builder.build(
|
||||
chapter=current_chapter,
|
||||
project=project,
|
||||
outline=outline,
|
||||
user_id=user_id,
|
||||
db=db,
|
||||
style_content=style_content,
|
||||
target_word_count=target_word_count,
|
||||
temp_narrative_perspective=temp_narrative_perspective
|
||||
)
|
||||
|
||||
await tracker.loading("上下文构建完成", 0.8)
|
||||
|
||||
# 确定叙事人称
|
||||
chapter_perspective = (
|
||||
temp_narrative_perspective or
|
||||
project.narrative_perspective or
|
||||
'第三人称'
|
||||
)
|
||||
|
||||
# === 准备提示词 ===
|
||||
if outline_mode == 'one-to-one':
|
||||
if chapter_context.continuation_point:
|
||||
template = await PromptService.get_template("CHAPTER_GENERATION_ONE_TO_ONE_NEXT", user_id, db)
|
||||
base_prompt = PromptService.format_prompt(
|
||||
template,
|
||||
project_title=project.title,
|
||||
chapter_number=current_chapter.chapter_number,
|
||||
chapter_title=current_chapter.title,
|
||||
chapter_outline=chapter_context.chapter_outline,
|
||||
target_word_count=target_word_count,
|
||||
genre=project.genre or '未设定',
|
||||
narrative_perspective=chapter_perspective,
|
||||
previous_chapter_content=chapter_context.continuation_point,
|
||||
previous_chapter_summary=chapter_context.previous_chapter_summary or '(无上一章摘要)',
|
||||
characters_info=chapter_context.chapter_characters or '暂无角色信息',
|
||||
chapter_careers=chapter_context.chapter_careers or '暂无职业信息',
|
||||
foreshadow_reminders=chapter_context.foreshadow_reminders or '暂无需要关注的伏笔',
|
||||
relevant_memories=chapter_context.relevant_memories or '暂无相关记忆'
|
||||
)
|
||||
else:
|
||||
template = await PromptService.get_template("CHAPTER_GENERATION_ONE_TO_ONE", user_id, db)
|
||||
base_prompt = PromptService.format_prompt(
|
||||
template,
|
||||
project_title=project.title,
|
||||
chapter_number=current_chapter.chapter_number,
|
||||
chapter_title=current_chapter.title,
|
||||
chapter_outline=chapter_context.chapter_outline,
|
||||
target_word_count=target_word_count,
|
||||
genre=project.genre or '未设定',
|
||||
narrative_perspective=chapter_perspective,
|
||||
characters_info=chapter_context.chapter_characters or '暂无角色信息',
|
||||
chapter_careers=chapter_context.chapter_careers or '暂无职业信息',
|
||||
foreshadow_reminders=chapter_context.foreshadow_reminders or '暂无需要关注的伏笔',
|
||||
relevant_memories=chapter_context.relevant_memories or '暂无相关记忆'
|
||||
)
|
||||
else:
|
||||
if chapter_context.continuation_point:
|
||||
previous_summary = chapter_context.previous_chapter_summary or "(无上一章摘要,请根据锚点续写)"
|
||||
template = await PromptService.get_template("CHAPTER_GENERATION_ONE_TO_MANY_NEXT", user_id, db)
|
||||
base_prompt = PromptService.format_prompt(
|
||||
template,
|
||||
project_title=project.title,
|
||||
chapter_number=current_chapter.chapter_number,
|
||||
chapter_title=current_chapter.title,
|
||||
chapter_outline=chapter_context.chapter_outline,
|
||||
target_word_count=target_word_count,
|
||||
continuation_point=chapter_context.continuation_point,
|
||||
genre=project.genre or '未设定',
|
||||
narrative_perspective=chapter_perspective,
|
||||
characters_info=chapter_context.chapter_characters or '暂无角色信息',
|
||||
chapter_careers=chapter_context.chapter_careers or '暂无职业信息',
|
||||
foreshadow_reminders=chapter_context.foreshadow_reminders or '暂无需要关注的伏笔',
|
||||
previous_chapter_summary=previous_summary,
|
||||
recent_chapters_context=chapter_context.recent_chapters_context or '',
|
||||
relevant_memories=chapter_context.relevant_memories or ''
|
||||
)
|
||||
else:
|
||||
template = await PromptService.get_template("CHAPTER_GENERATION_ONE_TO_MANY", user_id, db)
|
||||
base_prompt = PromptService.format_prompt(
|
||||
template,
|
||||
project_title=project.title,
|
||||
chapter_number=current_chapter.chapter_number,
|
||||
chapter_title=current_chapter.title,
|
||||
chapter_outline=chapter_context.chapter_outline,
|
||||
target_word_count=target_word_count,
|
||||
genre=project.genre or '未设定',
|
||||
narrative_perspective=chapter_perspective,
|
||||
characters_info=chapter_context.chapter_characters or '暂无角色信息',
|
||||
chapter_careers=chapter_context.chapter_careers or '暂无职业信息',
|
||||
foreshadow_reminders=chapter_context.foreshadow_reminders or '暂无需要关注的伏笔',
|
||||
relevant_memories=chapter_context.relevant_memories or '暂无相关记忆'
|
||||
)
|
||||
|
||||
# 应用写作风格
|
||||
if style_content:
|
||||
prompt = WritingStyleManager.apply_style_to_prompt(base_prompt, style_content)
|
||||
else:
|
||||
prompt = base_prompt
|
||||
|
||||
# === 准备阶段 ===
|
||||
await tracker.preparing("准备AI提示词...")
|
||||
|
||||
system_prompt_with_style = None
|
||||
if style_content:
|
||||
system_prompt_with_style = f"""【🎨 写作风格要求 - 最高优先级】
|
||||
|
||||
{style_content}
|
||||
|
||||
⚠️ 请严格遵循上述写作风格要求进行创作,这是最重要的指令!
|
||||
确保在整个章节创作过程中始终保持风格的一致性。"""
|
||||
|
||||
calculated_max_tokens = int(target_word_count * 3)
|
||||
calculated_max_tokens = max(2000, min(calculated_max_tokens, 16000))
|
||||
|
||||
generate_kwargs = {
|
||||
"prompt": prompt,
|
||||
"system_prompt": system_prompt_with_style,
|
||||
"tool_choice": "required",
|
||||
"max_tokens": calculated_max_tokens
|
||||
}
|
||||
if custom_model:
|
||||
generate_kwargs["model"] = custom_model
|
||||
|
||||
# === 生成阶段 ===
|
||||
full_content = ""
|
||||
chunk_count = 0
|
||||
|
||||
await tracker.generating(
|
||||
current_chars=0,
|
||||
estimated_total=target_word_count
|
||||
)
|
||||
|
||||
async for chunk in ai_service.generate_text_stream(**generate_kwargs):
|
||||
# 检查是否被取消
|
||||
if chunk_count % 10 == 0 and await tracker.check_cancelled():
|
||||
logger.info(f"🚫 后台章节生成被取消: {chapter_id}")
|
||||
return
|
||||
|
||||
full_content += chunk
|
||||
chunk_count += 1
|
||||
|
||||
# 每10个chunk更新一次进度
|
||||
if chunk_count % 10 == 0:
|
||||
await tracker.generating(
|
||||
current_chars=len(full_content),
|
||||
estimated_total=target_word_count,
|
||||
message=f'正在创作中... 已生成 {len(full_content)} 字'
|
||||
)
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# === 保存阶段 ===
|
||||
await tracker.saving("正在保存章节...", 0.3)
|
||||
|
||||
async with write_lock:
|
||||
# 重新获取章节(确保最新状态)
|
||||
chapter_result = await db.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
current_chapter = chapter_result.scalar_one_or_none()
|
||||
if not current_chapter:
|
||||
await tracker.error("保存时章节不存在")
|
||||
return
|
||||
|
||||
old_word_count = current_chapter.word_count or 0
|
||||
current_chapter.content = full_content
|
||||
new_word_count = len(full_content)
|
||||
current_chapter.word_count = new_word_count
|
||||
current_chapter.status = "completed"
|
||||
|
||||
# 更新项目字数
|
||||
project_result = await db.execute(
|
||||
select(Project).where(Project.id == current_chapter.project_id)
|
||||
)
|
||||
project_obj = project_result.scalar_one_or_none()
|
||||
if project_obj:
|
||||
project_obj.current_words = (project_obj.current_words or 0) - old_word_count + new_word_count
|
||||
|
||||
# 记录生成历史
|
||||
history = GenerationHistory(
|
||||
project_id=current_chapter.project_id,
|
||||
chapter_id=current_chapter.id,
|
||||
prompt=f"创作章节: 第{current_chapter.chapter_number}章 {current_chapter.title}",
|
||||
generated_content=full_content[:500] if len(full_content) > 500 else full_content,
|
||||
model="default"
|
||||
)
|
||||
db.add(history)
|
||||
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"✅ 后台创作章节 {chapter_id} 完成,共 {new_word_count} 字")
|
||||
|
||||
# 🔮 自动标记伏笔
|
||||
try:
|
||||
plant_result = await foreshadow_service.auto_plant_pending_foreshadows(
|
||||
db=db,
|
||||
project_id=current_chapter.project_id,
|
||||
chapter_id=chapter_id,
|
||||
chapter_number=current_chapter.chapter_number,
|
||||
chapter_content=full_content
|
||||
)
|
||||
if plant_result.get('planted_count', 0) > 0:
|
||||
logger.info(f"🔮 自动标记伏笔已埋入: {plant_result['planted_count']}个")
|
||||
except Exception as plant_error:
|
||||
logger.warning(f"⚠️ 自动标记伏笔埋入失败: {str(plant_error)}")
|
||||
|
||||
# 创建分析任务
|
||||
analysis_task = AnalysisTask(
|
||||
chapter_id=chapter_id,
|
||||
user_id=user_id,
|
||||
project_id=current_chapter.project_id,
|
||||
status='pending',
|
||||
progress=0
|
||||
)
|
||||
db.add(analysis_task)
|
||||
await db.commit()
|
||||
await db.refresh(analysis_task)
|
||||
|
||||
logger.info(f"📋 后台生成:已创建分析任务: {analysis_task.id}")
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
# 启动后台分析
|
||||
asyncio.create_task(
|
||||
analyze_chapter_background(
|
||||
chapter_id=chapter_id,
|
||||
user_id=user_id,
|
||||
project_id=current_chapter.project_id,
|
||||
task_id=analysis_task.id,
|
||||
ai_service=ai_service
|
||||
)
|
||||
)
|
||||
|
||||
# === 完成 ===
|
||||
await tracker.complete(f"创作完成!共 {new_word_count} 字")
|
||||
|
||||
# 更新任务结果
|
||||
from app.services.background_task_service import background_task_service
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession as BgAsyncSession
|
||||
from app.database import get_engine as bg_get_engine
|
||||
try:
|
||||
engine = await bg_get_engine(user_id)
|
||||
AsyncSessionLocal = async_sessionmaker(engine, class_=BgAsyncSession, expire_on_commit=False)
|
||||
async with AsyncSessionLocal() as result_db:
|
||||
from sqlalchemy import update as sql_update
|
||||
await result_db.execute(
|
||||
sql_update(BackgroundTask)
|
||||
.where(BackgroundTask.id == task_id)
|
||||
.values(task_result={
|
||||
"chapter_id": chapter_id,
|
||||
"word_count": new_word_count,
|
||||
"analysis_task_id": analysis_task.id
|
||||
})
|
||||
)
|
||||
await result_db.commit()
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 更新任务结果失败: {e}")
|
||||
|
||||
|
||||
def _build_analysis_task_status_payload(
|
||||
chapter_id: str,
|
||||
task: Optional[AnalysisTask],
|
||||
|
||||
@@ -1717,6 +1717,523 @@ async def continue_outline_generator(
|
||||
yield await tracker.error(f"续写失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/generate", summary="AI生成/续写大纲(后台任务)")
|
||||
async def generate_outline_task(
|
||||
data: Dict[str, Any],
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
"""
|
||||
使用后台任务生成或续写小说大纲(不怕断连,关闭浏览器也继续运行)
|
||||
|
||||
返回task_id,前端通过 GET /api/tasks/{task_id} 轮询进度
|
||||
|
||||
支持模式:
|
||||
- auto/new/continue(同 generate-stream)
|
||||
"""
|
||||
from app.services.background_task_service import background_task_service, TaskProgressTracker
|
||||
from app.database import get_engine
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession as NewAsyncSession
|
||||
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
project = await verify_project_access(data.get("project_id"), user_id, db)
|
||||
|
||||
# 判断模式
|
||||
mode = data.get("mode", "auto")
|
||||
existing_result = await db.execute(
|
||||
select(Outline)
|
||||
.where(Outline.project_id == data.get("project_id"))
|
||||
.order_by(Outline.order_index)
|
||||
)
|
||||
existing_outlines = existing_result.scalars().all()
|
||||
|
||||
if mode == "auto":
|
||||
mode = "continue" if existing_outlines else "new"
|
||||
|
||||
data["user_id"] = user_id
|
||||
data["mode"] = mode
|
||||
|
||||
if mode == "continue" and not existing_outlines:
|
||||
raise HTTPException(status_code=400, detail="续写模式需要已有大纲")
|
||||
|
||||
# 创建后台任务
|
||||
task_type = "outline_new" if mode == "new" else "outline_continue"
|
||||
task = await background_task_service.create_task(
|
||||
user_id=user_id,
|
||||
project_id=data.get("project_id"),
|
||||
task_type=task_type,
|
||||
task_input=data,
|
||||
db=db
|
||||
)
|
||||
|
||||
# 后台执行的函数
|
||||
async def _run_outline_generation(task_id: str, user_id: str):
|
||||
engine = await get_engine(user_id)
|
||||
AsyncSessionLocal = async_sessionmaker(engine, class_=NewAsyncSession, expire_on_commit=False)
|
||||
|
||||
async with AsyncSessionLocal() as bg_db:
|
||||
tracker = TaskProgressTracker(task_id, user_id, "大纲")
|
||||
try:
|
||||
await tracker.start()
|
||||
|
||||
# 获取AI服务(需要在后台创建新实例)
|
||||
from app.api.settings import get_user_ai_service_from_db
|
||||
bg_ai_service = await get_user_ai_service_from_db(user_id, bg_db)
|
||||
|
||||
if mode == "new":
|
||||
await _run_new_outline_bg(data, bg_db, bg_ai_service, tracker)
|
||||
else:
|
||||
await _run_continue_outline_bg(data, bg_db, bg_ai_service, tracker, user_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 后台大纲生成失败: {e}", exc_info=True)
|
||||
await tracker.error(str(e))
|
||||
|
||||
await background_task_service.spawn_background_task(
|
||||
task.id, user_id, _run_outline_generation
|
||||
)
|
||||
|
||||
return {
|
||||
"task_id": task.id,
|
||||
"task_type": task_type,
|
||||
"status": "pending",
|
||||
"message": "任务已创建,请通过 GET /api/tasks/{task_id} 查询进度"
|
||||
}
|
||||
|
||||
|
||||
async def _run_new_outline_bg(
|
||||
data: Dict[str, Any],
|
||||
db: AsyncSession,
|
||||
user_ai_service: AIService,
|
||||
tracker
|
||||
):
|
||||
"""后台执行全新大纲生成"""
|
||||
from app.database import get_engine
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession as BgAsyncSession
|
||||
|
||||
project_id = data.get("project_id")
|
||||
chapter_count = int(data.get("chapter_count", 10))
|
||||
user_id_for_mcp = data.get("user_id")
|
||||
|
||||
await tracker.loading("加载项目信息...", 0.3)
|
||||
result = await db.execute(select(Project).where(Project.id == project_id))
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
await tracker.error("项目不存在")
|
||||
return
|
||||
|
||||
await tracker.loading(f"准备生成{chapter_count}章大纲...", 0.6)
|
||||
characters_result = await db.execute(select(Character).where(Character.project_id == project_id))
|
||||
characters = characters_result.scalars().all()
|
||||
characters_info = _build_characters_info(characters)
|
||||
|
||||
if user_id_for_mcp:
|
||||
user_ai_service.user_id = user_id_for_mcp
|
||||
user_ai_service.db_session = db
|
||||
|
||||
await tracker.preparing("准备AI提示词...")
|
||||
template = await PromptService.get_template("OUTLINE_CREATE", user_id_for_mcp, db)
|
||||
prompt = PromptService.format_prompt(
|
||||
template,
|
||||
title=project.title,
|
||||
theme=data.get("theme") or project.theme or "未设定",
|
||||
genre=data.get("genre") or project.genre or "通用",
|
||||
chapter_count=chapter_count,
|
||||
narrative_perspective=data.get("narrative_perspective") or "第三人称",
|
||||
time_period=project.world_time_period or "未设定",
|
||||
location=project.world_location or "未设定",
|
||||
atmosphere=project.world_atmosphere or "未设定",
|
||||
rules=project.world_rules or "未设定",
|
||||
characters_info=characters_info or "暂无角色信息",
|
||||
requirements=data.get("requirements") or "",
|
||||
mcp_references=""
|
||||
)
|
||||
|
||||
model_param = data.get("model")
|
||||
provider_param = data.get("provider")
|
||||
|
||||
estimated_total = chapter_count * 1000
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
await tracker.generating(current_chars=0, estimated_total=estimated_total)
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=prompt, provider=provider_param, model=model_param
|
||||
):
|
||||
chunk_count += 1
|
||||
accumulated_text += chunk
|
||||
if chunk_count % 10 == 0:
|
||||
if await tracker.check_cancelled():
|
||||
await tracker.error("任务已取消")
|
||||
return
|
||||
await tracker.generating(
|
||||
current_chars=len(accumulated_text),
|
||||
estimated_total=estimated_total
|
||||
)
|
||||
|
||||
await tracker.parsing("解析大纲数据...")
|
||||
ai_content = accumulated_text
|
||||
|
||||
# 解析响应(带重试)
|
||||
max_retries = 2
|
||||
retry_count = 0
|
||||
outline_data = None
|
||||
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
outline_data = _parse_ai_response(ai_content, raise_on_error=True)
|
||||
break
|
||||
except JSONParseError:
|
||||
retry_count += 1
|
||||
if retry_count > max_retries:
|
||||
outline_data = _parse_ai_response(ai_content, raise_on_error=False)
|
||||
break
|
||||
await tracker.retry(retry_count, max_retries, "JSON解析失败")
|
||||
tracker.reset_generating_progress()
|
||||
accumulated_text = ""
|
||||
retry_prompt = prompt + "\n\n【重要提醒】请确保返回完整的JSON数组。"
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=retry_prompt, provider=provider_param, model=model_param
|
||||
):
|
||||
accumulated_text += chunk
|
||||
ai_content = accumulated_text
|
||||
|
||||
# ✅ P0修复:先保存新数据,再删除旧数据
|
||||
await tracker.saving("保存新大纲到数据库...", 0.2)
|
||||
outlines = await _save_outlines(project_id, outline_data, db, start_index=1)
|
||||
await db.commit() # 先提交新数据!
|
||||
logger.info(f"✅ 新大纲已保存: {len(outlines)} 章")
|
||||
|
||||
# 新数据安全后,再清理旧数据
|
||||
await tracker.saving("清理旧数据...", 0.6)
|
||||
try:
|
||||
from sqlalchemy import delete as sql_delete
|
||||
|
||||
# 获取旧大纲(不包括刚保存的)
|
||||
new_outline_ids = [o.id for o in outlines]
|
||||
old_outlines_result = await db.execute(
|
||||
select(Outline).where(
|
||||
Outline.project_id == project_id,
|
||||
~Outline.id.in_(new_outline_ids)
|
||||
)
|
||||
)
|
||||
old_outlines = old_outlines_result.scalars().all()
|
||||
|
||||
if old_outlines:
|
||||
old_outline_ids = [o.id for o in old_outlines]
|
||||
|
||||
# 清理旧章节
|
||||
old_chapters_result = await db.execute(
|
||||
select(Chapter).where(
|
||||
Chapter.project_id == project_id,
|
||||
~Chapter.id.in_([ch.id for ch in await db.execute(
|
||||
select(Chapter).where(Chapter.outline_id.in_(new_outline_ids) if new_outline_ids else False)
|
||||
).scalars().all()] if new_outline_ids else [])
|
||||
)
|
||||
)
|
||||
# 简化:删除不属于新大纲的旧章节
|
||||
# 先获取新大纲对应的章节(one-to-one模式下通过chapter_number匹配)
|
||||
new_order_indexes = [o.order_index for o in outlines]
|
||||
|
||||
if project.outline_mode == 'one-to-one':
|
||||
old_chapters_result = await db.execute(
|
||||
select(Chapter).where(
|
||||
Chapter.project_id == project_id,
|
||||
~Chapter.chapter_number.in_(new_order_indexes)
|
||||
)
|
||||
)
|
||||
else:
|
||||
old_chapters_result = await db.execute(
|
||||
select(Chapter).where(
|
||||
Chapter.project_id == project_id,
|
||||
Chapter.outline_id.in_(old_outline_ids)
|
||||
)
|
||||
)
|
||||
|
||||
old_chapters = old_chapters_result.scalars().all()
|
||||
deleted_word_count = sum(ch.word_count or 0 for ch in old_chapters)
|
||||
|
||||
# 清理伏笔和记忆
|
||||
for ch in old_chapters:
|
||||
try:
|
||||
await memory_service.delete_chapter_memories(
|
||||
user_id=user_id_for_mcp, project_id=project_id, chapter_id=ch.id
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
await foreshadow_service.delete_chapter_foreshadows(
|
||||
db=db, project_id=project_id, chapter_id=ch.id, only_analysis_source=True
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 删除旧章节
|
||||
if project.outline_mode == 'one-to-one':
|
||||
await db.execute(
|
||||
sql_delete(Chapter).where(
|
||||
Chapter.project_id == project_id,
|
||||
~Chapter.chapter_number.in_(new_order_indexes)
|
||||
)
|
||||
)
|
||||
else:
|
||||
await db.execute(
|
||||
sql_delete(Chapter).where(Chapter.outline_id.in_(old_outline_ids))
|
||||
)
|
||||
|
||||
if deleted_word_count > 0:
|
||||
project.current_words = max(0, project.current_words - deleted_word_count)
|
||||
|
||||
# 清理伏笔
|
||||
try:
|
||||
await foreshadow_service.clear_project_foreshadows_for_reset(db, project_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 清理分析
|
||||
try:
|
||||
from app.models.memory import PlotAnalysis
|
||||
await db.execute(sql_delete(PlotAnalysis).where(PlotAnalysis.project_id == project_id))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 删除旧大纲
|
||||
await db.execute(
|
||||
sql_delete(Outline).where(Outline.id.in_(old_outline_ids))
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
logger.info(f"✅ 旧数据清理完成: 删除 {len(old_outlines)} 个旧大纲, {len(old_chapters)} 个旧章节")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 清理旧数据失败(新数据已安全保存): {e}")
|
||||
# 新数据已保存,旧数据清理失败不影响
|
||||
|
||||
# 角色校验
|
||||
await tracker.saving("🎭 校验角色信息...", 0.7)
|
||||
try:
|
||||
await _check_and_create_missing_characters_from_outlines(
|
||||
outline_data=outline_data, project_id=project_id, db=db,
|
||||
user_ai_service=user_ai_service, user_id=data.get("user_id"),
|
||||
enable_mcp=data.get("enable_mcp", True), tracker=tracker
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 组织校验
|
||||
try:
|
||||
await _check_and_create_missing_organizations_from_outlines(
|
||||
outline_data=outline_data, project_id=project_id, db=db,
|
||||
user_ai_service=user_ai_service, user_id=data.get("user_id"),
|
||||
enable_mcp=data.get("enable_mcp", True), tracker=tracker
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 保存结果到任务记录
|
||||
result_data = {
|
||||
"message": f"成功生成{len(outlines)}章大纲",
|
||||
"total_chapters": len(outlines),
|
||||
"outline_ids": [o.id for o in outlines]
|
||||
}
|
||||
|
||||
# 更新任务结果
|
||||
from app.models.background_task import BackgroundTask
|
||||
task_result = await db.execute(select(BackgroundTask).where(BackgroundTask.id == tracker.task_id))
|
||||
bg_task = task_result.scalar_one_or_none()
|
||||
if bg_task:
|
||||
bg_task.task_result = result_data
|
||||
await db.commit()
|
||||
|
||||
await tracker.complete(f"成功生成{len(outlines)}章大纲")
|
||||
logger.info(f"✅ 后台大纲生成完成: {len(outlines)} 章")
|
||||
|
||||
|
||||
async def _run_continue_outline_bg(
|
||||
data: Dict[str, Any],
|
||||
db: AsyncSession,
|
||||
user_ai_service: AIService,
|
||||
tracker,
|
||||
user_id: str
|
||||
):
|
||||
"""后台执行大纲续写"""
|
||||
project_id = data.get("project_id")
|
||||
total_chapters = int(data.get("chapter_count", 5))
|
||||
|
||||
await tracker.loading("加载项目信息...", 0.2)
|
||||
result = await db.execute(select(Project).where(Project.id == project_id))
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
await tracker.error("项目不存在")
|
||||
return
|
||||
|
||||
existing_result = await db.execute(
|
||||
select(Outline).where(Outline.project_id == project_id).order_by(Outline.order_index)
|
||||
)
|
||||
existing_outlines = existing_result.scalars().all()
|
||||
if not existing_outlines:
|
||||
await tracker.error("续写模式需要已有大纲")
|
||||
return
|
||||
|
||||
last_chapter_number = existing_outlines[-1].order_index
|
||||
|
||||
characters_result = await db.execute(select(Character).where(Character.project_id == project_id))
|
||||
characters = characters_result.scalars().all()
|
||||
|
||||
batch_size = 5
|
||||
total_batches = (total_chapters + batch_size - 1) // batch_size
|
||||
all_new_outlines = []
|
||||
current_start_chapter = last_chapter_number + 1
|
||||
|
||||
stage_instructions = {
|
||||
"development": "继续展开情节,深化角色关系",
|
||||
"climax": "进入故事高潮,矛盾激化",
|
||||
"ending": "解决主要冲突,给出结局"
|
||||
}
|
||||
stage_instruction = stage_instructions.get(data.get("plot_stage", "development"), "")
|
||||
|
||||
for batch_num in range(total_batches):
|
||||
if await tracker.check_cancelled():
|
||||
await tracker.error("任务已取消")
|
||||
return
|
||||
|
||||
remaining = total_chapters - len(all_new_outlines)
|
||||
current_batch_size = min(batch_size, remaining)
|
||||
tracker.reset_generating_progress()
|
||||
|
||||
await tracker.generating(
|
||||
message=f"📝 第{batch_num + 1}/{total_batches}批: 生成第{current_start_chapter}-{current_start_chapter + current_batch_size - 1}章"
|
||||
)
|
||||
|
||||
latest_result = await db.execute(
|
||||
select(Outline).where(Outline.project_id == project_id).order_by(Outline.order_index)
|
||||
)
|
||||
latest_outlines = latest_result.scalars().all()
|
||||
|
||||
context = await _build_outline_continue_context(
|
||||
project=project, latest_outlines=latest_outlines, characters=characters,
|
||||
chapter_count=current_batch_size,
|
||||
plot_stage=data.get("plot_stage", "development"),
|
||||
story_direction=data.get("story_direction", "自然延续"),
|
||||
requirements=data.get("requirements", ""), db=db
|
||||
)
|
||||
|
||||
user_ai_service.user_id = user_id
|
||||
user_ai_service.db_session = db
|
||||
|
||||
# 获取伏笔提醒
|
||||
foreshadow_reminders_text = "暂无需要关注的伏笔"
|
||||
try:
|
||||
foreshadow_context = await foreshadow_service.build_chapter_context(
|
||||
db=db, project_id=project_id, chapter_number=current_start_chapter,
|
||||
include_pending=False, include_overdue=True, lookahead=10
|
||||
)
|
||||
if foreshadow_context and foreshadow_context.get("context_text"):
|
||||
foreshadow_reminders_text = foreshadow_context["context_text"]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
template = await PromptService.get_template("OUTLINE_CONTINUE", user_id, db)
|
||||
prompt = PromptService.format_prompt(
|
||||
template,
|
||||
title=project.title, theme=project.theme or "未设定",
|
||||
genre=project.genre or "通用",
|
||||
narrative_perspective=project.narrative_perspective or "第三人称",
|
||||
time_period=project.world_time_period or "未设定",
|
||||
location=project.world_location or "未设定",
|
||||
atmosphere=project.world_atmosphere or "未设定",
|
||||
rules=project.world_rules or "未设定",
|
||||
recent_outlines=context['recent_outlines'],
|
||||
characters_info=context['characters_info'],
|
||||
foreshadow_reminders=foreshadow_reminders_text,
|
||||
chapter_count=current_batch_size,
|
||||
start_chapter=current_start_chapter,
|
||||
end_chapter=current_start_chapter + current_batch_size - 1,
|
||||
current_chapter_count=len(latest_outlines),
|
||||
plot_stage_instruction=stage_instruction,
|
||||
story_direction=data.get("story_direction", "自然延续"),
|
||||
requirements=data.get("requirements", ""),
|
||||
mcp_references=""
|
||||
)
|
||||
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
estimated_chars = current_batch_size * 1000
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=prompt, provider=data.get("provider"), model=data.get("model")
|
||||
):
|
||||
chunk_count += 1
|
||||
accumulated_text += chunk
|
||||
if chunk_count % 10 == 0:
|
||||
await tracker.generating(
|
||||
current_chars=len(accumulated_text), estimated_total=estimated_chars,
|
||||
message=f"📝 第{batch_num + 1}/{total_batches}批生成中..."
|
||||
)
|
||||
|
||||
await tracker.parsing(f"解析第{batch_num + 1}批数据...")
|
||||
|
||||
# 解析
|
||||
max_retries = 2
|
||||
retry_count = 0
|
||||
outline_data = None
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
outline_data = _parse_ai_response(accumulated_text, raise_on_error=True)
|
||||
break
|
||||
except JSONParseError:
|
||||
retry_count += 1
|
||||
if retry_count > max_retries:
|
||||
outline_data = _parse_ai_response(accumulated_text, raise_on_error=False)
|
||||
break
|
||||
await tracker.retry(retry_count, max_retries, "JSON解析失败")
|
||||
tracker.reset_generating_progress()
|
||||
accumulated_text = ""
|
||||
retry_prompt = prompt + "\n\n【重要提醒】请确保返回完整的JSON数组。"
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=retry_prompt, provider=data.get("provider"), model=data.get("model")
|
||||
):
|
||||
accumulated_text += chunk
|
||||
|
||||
# 保存当前批次
|
||||
await tracker.saving(f"保存第{batch_num + 1}批大纲...", 0.5)
|
||||
batch_outlines = await _save_outlines(
|
||||
project_id, outline_data, db, start_index=current_start_chapter
|
||||
)
|
||||
await db.commit()
|
||||
all_new_outlines.extend(batch_outlines)
|
||||
current_start_chapter += current_batch_size
|
||||
|
||||
# 角色校验
|
||||
try:
|
||||
await _check_and_create_missing_characters_from_outlines(
|
||||
outline_data=outline_data, project_id=project_id, db=db,
|
||||
user_ai_service=user_ai_service, user_id=user_id,
|
||||
enable_mcp=data.get("enable_mcp", True), tracker=tracker
|
||||
)
|
||||
await db.commit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 保存结果
|
||||
result_data = {
|
||||
"message": f"成功续写{len(all_new_outlines)}章大纲",
|
||||
"total_chapters": len(all_new_outlines),
|
||||
"outline_ids": [o.id for o in all_new_outlines]
|
||||
}
|
||||
from app.models.background_task import BackgroundTask
|
||||
task_result = await db.execute(select(BackgroundTask).where(BackgroundTask.id == tracker.task_id))
|
||||
bg_task = task_result.scalar_one_or_none()
|
||||
if bg_task:
|
||||
bg_task.task_result = result_data
|
||||
await db.commit()
|
||||
|
||||
await tracker.complete(f"成功续写{len(all_new_outlines)}章大纲")
|
||||
logger.info(f"✅ 后台大纲续写完成: {len(all_new_outlines)} 章")
|
||||
|
||||
|
||||
@router.post("/generate-stream", summary="AI生成/续写大纲(SSE流式)")
|
||||
async def generate_outline_stream(
|
||||
data: Dict[str, Any],
|
||||
|
||||
@@ -160,6 +160,44 @@ async def get_user_ai_service(
|
||||
)
|
||||
|
||||
|
||||
async def get_user_ai_service_from_db(user_id: str, db: AsyncSession) -> AIService:
|
||||
"""
|
||||
从数据库直接创建用户AI服务实例(用于后台任务,不依赖FastAPI的Depends)
|
||||
"""
|
||||
from app.models.mcp_plugin import MCPPlugin
|
||||
|
||||
result = await db.execute(
|
||||
select(Settings).where(Settings.user_id == user_id)
|
||||
)
|
||||
settings = result.scalar_one_or_none()
|
||||
|
||||
if not settings:
|
||||
env_defaults = read_env_defaults()
|
||||
settings = Settings(user_id=user_id, **env_defaults)
|
||||
db.add(settings)
|
||||
await db.commit()
|
||||
await db.refresh(settings)
|
||||
|
||||
mcp_result = await db.execute(
|
||||
select(MCPPlugin).where(MCPPlugin.user_id == user_id)
|
||||
)
|
||||
mcp_plugins = mcp_result.scalars().all()
|
||||
enable_mcp = any(plugin.enabled for plugin in mcp_plugins) if mcp_plugins else False
|
||||
|
||||
return create_user_ai_service_with_mcp(
|
||||
api_provider=settings.api_provider,
|
||||
api_key=settings.api_key,
|
||||
api_base_url=settings.api_base_url or "",
|
||||
model_name=settings.llm_model,
|
||||
temperature=settings.temperature,
|
||||
max_tokens=settings.max_tokens,
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
system_prompt=settings.system_prompt,
|
||||
enable_mcp=enable_mcp,
|
||||
)
|
||||
|
||||
|
||||
@router.get("", response_model=SettingsResponse)
|
||||
async def get_settings(
|
||||
user: User = Depends(require_login),
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
"""后台任务API - 查询状态、取消任务"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Optional
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.background_task import BackgroundTask
|
||||
from app.services.background_task_service import background_task_service
|
||||
from app.logger import get_logger
|
||||
|
||||
router = APIRouter(prefix="/tasks", tags=["后台任务"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@router.get("/{task_id}", summary="获取任务状态")
|
||||
async def get_task_status(
|
||||
task_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取后台任务的状态和进度"""
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
task = await background_task_service.get_task(task_id, user_id, db)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="任务不存在")
|
||||
|
||||
return {
|
||||
"id": task.id,
|
||||
"task_type": task.task_type,
|
||||
"project_id": task.project_id,
|
||||
"status": task.status,
|
||||
"progress": task.progress,
|
||||
"status_message": task.status_message,
|
||||
"progress_details": task.progress_details,
|
||||
"error_message": task.error_message,
|
||||
"task_result": task.task_result,
|
||||
"retry_count": task.retry_count,
|
||||
"cancel_requested": task.cancel_requested,
|
||||
"created_at": task.created_at.isoformat() if task.created_at else None,
|
||||
"started_at": task.started_at.isoformat() if task.started_at else None,
|
||||
"completed_at": task.completed_at.isoformat() if task.completed_at else None,
|
||||
"updated_at": task.updated_at.isoformat() if task.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
@router.get("", summary="获取任务列表")
|
||||
async def get_tasks(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
task_type: Optional[str] = None,
|
||||
limit: int = 20,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取项目的后台任务列表"""
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
tasks = await background_task_service.get_project_tasks(
|
||||
project_id, user_id, db, task_type=task_type, limit=limit
|
||||
)
|
||||
|
||||
return {
|
||||
"items": [
|
||||
{
|
||||
"id": t.id,
|
||||
"task_type": t.task_type,
|
||||
"status": t.status,
|
||||
"progress": t.progress,
|
||||
"status_message": t.status_message,
|
||||
"progress_details": t.progress_details,
|
||||
"error_message": t.error_message,
|
||||
"created_at": t.created_at.isoformat() if t.created_at else None,
|
||||
"completed_at": t.completed_at.isoformat() if t.completed_at else None,
|
||||
}
|
||||
for t in tasks
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{task_id}/cancel", summary="取消任务")
|
||||
async def cancel_task(
|
||||
task_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""请求取消后台任务"""
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
success = await background_task_service.cancel_task(task_id, user_id, db)
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="无法取消任务(不存在或已完成)")
|
||||
|
||||
return {"message": "任务已取消", "task_id": task_id}
|
||||
|
||||
|
||||
@router.delete("/{task_id}", summary="删除任务记录")
|
||||
async def delete_task(
|
||||
task_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""删除已完成/失败的任务记录"""
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
task = await background_task_service.get_task(task_id, user_id, db)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="任务不存在")
|
||||
|
||||
if task.status in ("pending", "running"):
|
||||
raise HTTPException(status_code=400, detail="无法删除进行中的任务,请先取消")
|
||||
|
||||
await db.delete(task)
|
||||
await db.commit()
|
||||
return {"message": "任务记录已删除"}
|
||||
@@ -21,7 +21,8 @@ from app.models import (
|
||||
Settings, WritingStyle, ProjectDefaultStyle,
|
||||
RelationshipType, CharacterRelationship, Organization, OrganizationMember,
|
||||
StoryMemory, PlotAnalysis, AnalysisTask, BatchGenerationTask,
|
||||
RegenerationTask, Career, CharacterCareer, User, MCPPlugin, PromptTemplate
|
||||
RegenerationTask, Career, CharacterCareer, User, MCPPlugin, PromptTemplate,
|
||||
BackgroundTask
|
||||
)
|
||||
|
||||
# 引擎缓存:每个用户一个引擎
|
||||
|
||||
+17
-2
@@ -29,7 +29,21 @@ async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
# 注册MCP状态同步服务
|
||||
register_status_sync()
|
||||
|
||||
|
||||
# 安全保障:确保后台任务表存在(兼容未执行Alembic迁移的旧部署)
|
||||
try:
|
||||
from app.database import get_engine
|
||||
from app.models.background_task import BackgroundTask
|
||||
_startup_engine = await get_engine("system")
|
||||
async with _startup_engine.begin() as conn:
|
||||
# 仅创建 background_tasks 表(如果不存在),不影响其他表
|
||||
await conn.run_sync(
|
||||
lambda sync_conn: BackgroundTask.__table__.create(sync_conn, checkfirst=True)
|
||||
)
|
||||
logger.info("后台任务表检查完成")
|
||||
except Exception as e:
|
||||
logger.warning(f"后台任务表检查失败(不影响启动): {e}")
|
||||
|
||||
logger.info("应用启动完成")
|
||||
|
||||
yield
|
||||
@@ -133,7 +147,7 @@ from app.api import (
|
||||
auth, users, settings, writing_styles, memories,
|
||||
mcp_plugins, admin, inspiration, prompt_templates,
|
||||
changelog, careers, foreshadows, prompt_workshop, book_import,
|
||||
project_covers
|
||||
project_covers, tasks
|
||||
)
|
||||
|
||||
app.include_router(auth.router, prefix="/api")
|
||||
@@ -159,6 +173,7 @@ app.include_router(prompt_templates.router, prefix="/api") # 提示词模板管
|
||||
app.include_router(changelog.router, prefix="/api") # 更新日志API
|
||||
app.include_router(prompt_workshop.router, prefix="/api") # 提示词工坊API
|
||||
app.include_router(book_import.router, prefix="/api") # 拆书导入API
|
||||
app.include_router(tasks.router, prefix="/api") # 后台任务API
|
||||
|
||||
static_dir = Path(__file__).parent.parent / "static"
|
||||
generated_assets_root_dir = Path(__file__).parent.parent / "storage"
|
||||
|
||||
@@ -18,6 +18,7 @@ from app.models.career import Career, CharacterCareer
|
||||
from app.models.prompt_template import PromptTemplate
|
||||
from app.models.foreshadow import Foreshadow
|
||||
from app.models.prompt_workshop import PromptWorkshopItem, PromptSubmission, PromptWorkshopLike
|
||||
from app.models.background_task import BackgroundTask
|
||||
|
||||
__all__ = [
|
||||
"Project",
|
||||
@@ -46,5 +47,6 @@ __all__ = [
|
||||
"Foreshadow",
|
||||
"PromptWorkshopItem",
|
||||
"PromptSubmission",
|
||||
"PromptWorkshopLike"
|
||||
]
|
||||
"PromptWorkshopLike",
|
||||
"BackgroundTask"
|
||||
]
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
"""后台任务数据模型 - 用于长时间运行的AI生成任务"""
|
||||
from sqlalchemy import Column, String, Integer, DateTime, Boolean, JSON, Text
|
||||
from sqlalchemy.sql import func
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class BackgroundTask(Base):
|
||||
"""后台任务表 - 追踪所有长时间运行的生成任务"""
|
||||
__tablename__ = "background_tasks"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
user_id = Column(String(100), nullable=False, index=True, comment="用户ID")
|
||||
project_id = Column(String(36), nullable=False, index=True, comment="项目ID")
|
||||
|
||||
# 任务类型
|
||||
task_type = Column(String(50), nullable=False, comment="任务类型: outline_new/outline_continue/outline_expand/chapter_generate/chapter_batch/wizard")
|
||||
|
||||
# 任务状态
|
||||
status = Column(String(20), default="pending", comment="任务状态: pending/running/completed/failed/cancelled")
|
||||
progress = Column(Integer, default=0, comment="进度百分比(0-100)")
|
||||
status_message = Column(String(500), comment="当前状态消息")
|
||||
|
||||
# 任务输入/输出
|
||||
task_input = Column(JSON, comment="任务输入参数(JSON)")
|
||||
task_result = Column(JSON, comment="任务结果(JSON)")
|
||||
error_message = Column(Text, comment="错误信息")
|
||||
|
||||
# 进度详情(用于前端展示实时进度)
|
||||
progress_details = Column(JSON, comment="进度详情: {stage, message, word_count, etc.}")
|
||||
|
||||
# 取消支持
|
||||
cancel_requested = Column(Boolean, default=False, comment="是否请求取消")
|
||||
|
||||
# 重试信息
|
||||
retry_count = Column(Integer, default=0, comment="已重试次数")
|
||||
max_retries = Column(Integer, default=3, comment="最大重试次数")
|
||||
|
||||
# 时间记录
|
||||
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
||||
started_at = Column(DateTime, comment="开始时间")
|
||||
completed_at = Column(DateTime, comment="完成时间")
|
||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<BackgroundTask(id={self.id[:8]}, type={self.task_type}, status={self.status}, progress={self.progress}%)>"
|
||||
@@ -138,6 +138,7 @@ class BatchGenerateRequest(BaseModel):
|
||||
enable_mcp: bool = Field(True, description="是否启用MCP工具增强(搜索参考资料)")
|
||||
max_retries: int = Field(3, description="每个章节的最大重试次数", ge=0, le=5)
|
||||
model: Optional[str] = Field(None, description="指定使用的AI模型,不提供则使用用户默认模型")
|
||||
narrative_perspective: Optional[str] = Field(None, description="临时指定叙事人称,不提供则使用项目默认")
|
||||
|
||||
|
||||
class BatchGenerateResponse(BaseModel):
|
||||
|
||||
@@ -0,0 +1,387 @@
|
||||
"""后台任务管理服务 - 管理长时间运行的AI生成任务"""
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional, Callable, Awaitable
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from sqlalchemy import select, update
|
||||
from app.database import get_engine
|
||||
from app.models.background_task import BackgroundTask
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class TaskProgressTracker:
|
||||
"""后台任务进度追踪器(替代SSE的WizardProgressTracker)"""
|
||||
|
||||
def __init__(self, task_id: str, user_id: str, task_name: str = "任务"):
|
||||
self.task_id = task_id
|
||||
self.user_id = user_id
|
||||
self.task_name = task_name
|
||||
self.current_progress = 0
|
||||
self._last_generating_progress = 20
|
||||
|
||||
async def _update_task(self, **kwargs):
|
||||
"""更新任务状态到数据库"""
|
||||
try:
|
||||
engine = await get_engine(self.user_id)
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
async with AsyncSessionLocal() as session:
|
||||
result = await session.execute(
|
||||
select(BackgroundTask).where(BackgroundTask.id == self.task_id)
|
||||
)
|
||||
task = result.scalar_one_or_none()
|
||||
if task:
|
||||
for key, value in kwargs.items():
|
||||
setattr(task, key, value)
|
||||
task.updated_at = datetime.now()
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 更新任务进度失败: {e}")
|
||||
|
||||
async def start(self, message: str = None):
|
||||
self.current_progress = 0
|
||||
msg = message or f"开始生成{self.task_name}..."
|
||||
await self._update_task(
|
||||
status="running", progress=0, status_message=msg,
|
||||
started_at=datetime.now(),
|
||||
progress_details={"stage": "init", "message": msg}
|
||||
)
|
||||
|
||||
async def loading(self, message: str = None, sub_progress: float = 0.5):
|
||||
progress = 5 + int(10 * sub_progress)
|
||||
self.current_progress = progress
|
||||
msg = message or "加载数据中..."
|
||||
await self._update_task(
|
||||
progress=progress, status_message=msg,
|
||||
progress_details={"stage": "loading", "message": msg}
|
||||
)
|
||||
|
||||
async def preparing(self, message: str = None):
|
||||
self.current_progress = 17
|
||||
msg = message or "准备AI提示词..."
|
||||
await self._update_task(
|
||||
progress=17, status_message=msg,
|
||||
progress_details={"stage": "preparing", "message": msg}
|
||||
)
|
||||
|
||||
async def generating(self, current_chars: int = 0, estimated_total: int = 5000,
|
||||
message: str = None, retry_count: int = 0, max_retries: int = 3):
|
||||
sub_progress = min(current_chars / max(estimated_total, 1), 1.0)
|
||||
progress = 20 + int(65 * sub_progress)
|
||||
if progress < self._last_generating_progress:
|
||||
progress = self._last_generating_progress
|
||||
else:
|
||||
self._last_generating_progress = progress
|
||||
self.current_progress = progress
|
||||
|
||||
retry_suffix = f" (重试 {retry_count}/{max_retries})" if retry_count > 0 else ""
|
||||
msg = message or f"生成{self.task_name}中... ({current_chars}字符){retry_suffix}"
|
||||
await self._update_task(
|
||||
progress=progress, status_message=msg,
|
||||
progress_details={"stage": "generating", "message": msg, "current_chars": current_chars}
|
||||
)
|
||||
|
||||
async def parsing(self, message: str = None):
|
||||
self.current_progress = 88
|
||||
msg = message or f"解析{self.task_name}数据..."
|
||||
await self._update_task(
|
||||
progress=88, status_message=msg,
|
||||
progress_details={"stage": "parsing", "message": msg}
|
||||
)
|
||||
|
||||
async def saving(self, message: str = None, sub_progress: float = 0.5):
|
||||
progress = 92 + int(6 * sub_progress)
|
||||
self.current_progress = progress
|
||||
msg = message or f"保存{self.task_name}到数据库..."
|
||||
await self._update_task(
|
||||
progress=progress, status_message=msg,
|
||||
progress_details={"stage": "saving", "message": msg}
|
||||
)
|
||||
|
||||
async def complete(self, message: str = None):
|
||||
self.current_progress = 100
|
||||
msg = message or f"{self.task_name}生成完成!"
|
||||
await self._update_task(
|
||||
status="completed", progress=100, status_message=msg,
|
||||
completed_at=datetime.now(),
|
||||
progress_details={"stage": "complete", "message": msg}
|
||||
)
|
||||
|
||||
async def error(self, error_message: str):
|
||||
await self._update_task(
|
||||
status="failed", error_message=error_message,
|
||||
status_message=f"失败: {error_message}",
|
||||
completed_at=datetime.now(),
|
||||
progress_details={"stage": "error", "message": error_message}
|
||||
)
|
||||
|
||||
async def warning(self, message: str):
|
||||
await self._update_task(
|
||||
status_message=f"⚠️ {message}",
|
||||
progress_details={"stage": "warning", "message": message}
|
||||
)
|
||||
|
||||
async def retry(self, retry_count: int, max_retries: int, reason: str = "准备重试"):
|
||||
msg = f"⚠️ {reason}... ({retry_count}/{max_retries})"
|
||||
await self._update_task(
|
||||
status_message=msg, retry_count=retry_count,
|
||||
progress_details={"stage": "retry", "message": msg, "retry_count": retry_count}
|
||||
)
|
||||
|
||||
def reset_generating_progress(self):
|
||||
self._last_generating_progress = 20
|
||||
|
||||
async def check_cancelled(self) -> bool:
|
||||
"""检查任务是否被取消"""
|
||||
try:
|
||||
engine = await get_engine(self.user_id)
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
async with AsyncSessionLocal() as session:
|
||||
result = await session.execute(
|
||||
select(BackgroundTask.cancel_requested)
|
||||
.where(BackgroundTask.id == self.task_id)
|
||||
)
|
||||
cancelled = result.scalar_one_or_none()
|
||||
return bool(cancelled)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
class BackgroundTaskService:
|
||||
"""后台任务管理服务(按用户排队:同用户任务逐个执行,不同用户可并发)"""
|
||||
|
||||
def __init__(self):
|
||||
self._user_queues: Dict[str, asyncio.Queue] = {} # user_id -> Queue
|
||||
self._user_workers: Dict[str, bool] = {} # user_id -> worker是否运行中
|
||||
|
||||
def _ensure_user_queue(self, user_id: str) -> asyncio.Queue:
|
||||
"""确保指定用户的队列已初始化"""
|
||||
if user_id not in self._user_queues:
|
||||
self._user_queues[user_id] = asyncio.Queue()
|
||||
return self._user_queues[user_id]
|
||||
|
||||
async def _start_user_worker(self, user_id: str):
|
||||
"""启动指定用户的工作协程"""
|
||||
if self._user_workers.get(user_id, False):
|
||||
return
|
||||
self._user_workers[user_id] = True
|
||||
asyncio.create_task(self._user_worker_loop(user_id))
|
||||
logger.info(f"📋 用户 {user_id[:8]} 的任务队列工作协程已启动")
|
||||
|
||||
async def _user_worker_loop(self, user_id: str):
|
||||
"""从指定用户的队列中逐个取出任务并执行"""
|
||||
queue = self._user_queues[user_id]
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
task_item = await queue.get()
|
||||
task_id = task_item["task_id"]
|
||||
task_func = task_item["task_func"]
|
||||
args = task_item["args"]
|
||||
kwargs = task_item["kwargs"]
|
||||
|
||||
logger.info(f"🔄 [用户{user_id[:8]}] 队列开始执行任务: {task_id[:8]} (队列剩余: {queue.qsize()})")
|
||||
|
||||
try:
|
||||
await task_func(task_id, args["user_id"], *args["extra_args"], **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 后台任务 {task_id[:8]} 异常: {e}", exc_info=True)
|
||||
# 确保任务状态更新为失败
|
||||
try:
|
||||
engine = await get_engine(user_id)
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
async with AsyncSessionLocal() as session:
|
||||
result = await session.execute(
|
||||
select(BackgroundTask).where(BackgroundTask.id == task_id)
|
||||
)
|
||||
task = result.scalar_one_or_none()
|
||||
if task and task.status == "running":
|
||||
task.status = "failed"
|
||||
task.error_message = str(e)
|
||||
task.status_message = f"任务失败: {str(e)}"
|
||||
task.completed_at = datetime.now()
|
||||
await session.commit()
|
||||
except Exception as update_err:
|
||||
logger.error(f"❌ 更新失败任务状态失败: {update_err}")
|
||||
finally:
|
||||
queue.task_done()
|
||||
logger.info(f"✅ [用户{user_id[:8]}] 队列任务完成: {task_id[:8]} (队列剩余: {queue.qsize()})")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ [用户{user_id[:8]}] 队列工作循环异常: {e}", exc_info=True)
|
||||
finally:
|
||||
# 工作协程退出时清理标记
|
||||
self._user_workers.pop(user_id, None)
|
||||
logger.info(f"📋 用户 {user_id[:8]} 的工作协程已退出")
|
||||
|
||||
@staticmethod
|
||||
async def create_task(
|
||||
user_id: str,
|
||||
project_id: str,
|
||||
task_type: str,
|
||||
task_input: Dict[str, Any] = None,
|
||||
db: AsyncSession = None
|
||||
) -> BackgroundTask:
|
||||
"""创建后台任务记录"""
|
||||
task = BackgroundTask(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
task_type=task_type,
|
||||
task_input=task_input or {},
|
||||
status="pending",
|
||||
progress=0,
|
||||
status_message="任务已创建,等待执行..."
|
||||
)
|
||||
db.add(task)
|
||||
await db.commit()
|
||||
await db.refresh(task)
|
||||
logger.info(f"📋 创建后台任务: {task.id[:8]} type={task_type} project={project_id[:8]}")
|
||||
return task
|
||||
|
||||
@staticmethod
|
||||
async def get_task(task_id: str, user_id: str, db: AsyncSession) -> Optional[BackgroundTask]:
|
||||
"""获取任务详情"""
|
||||
result = await db.execute(
|
||||
select(BackgroundTask).where(
|
||||
BackgroundTask.id == task_id,
|
||||
BackgroundTask.user_id == user_id
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@staticmethod
|
||||
async def get_project_tasks(
|
||||
project_id: str, user_id: str, db: AsyncSession,
|
||||
task_type: str = None, limit: int = 20
|
||||
) -> list:
|
||||
"""获取项目的任务列表"""
|
||||
query = (
|
||||
select(BackgroundTask)
|
||||
.where(
|
||||
BackgroundTask.project_id == project_id,
|
||||
BackgroundTask.user_id == user_id
|
||||
)
|
||||
.order_by(BackgroundTask.created_at.desc())
|
||||
)
|
||||
if task_type:
|
||||
query = query.where(BackgroundTask.task_type == task_type)
|
||||
query = query.limit(limit)
|
||||
result = await db.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
@staticmethod
|
||||
async def cancel_task(task_id: str, user_id: str, db: AsyncSession) -> bool:
|
||||
"""请求取消任务"""
|
||||
result = await db.execute(
|
||||
select(BackgroundTask).where(
|
||||
BackgroundTask.id == task_id,
|
||||
BackgroundTask.user_id == user_id
|
||||
)
|
||||
)
|
||||
task = result.scalar_one_or_none()
|
||||
if not task:
|
||||
return False
|
||||
if task.status not in ("pending", "running"):
|
||||
return False
|
||||
task.cancel_requested = True
|
||||
task.status = "cancelled"
|
||||
task.status_message = "任务已取消"
|
||||
task.completed_at = datetime.now()
|
||||
await db.commit()
|
||||
logger.info(f"🚫 取消任务: {task_id[:8]}")
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def cleanup_old_tasks(user_id: str, db: AsyncSession, days: int = 7):
|
||||
"""清理旧任务记录"""
|
||||
from sqlalchemy import delete as sql_delete
|
||||
from datetime import timedelta
|
||||
cutoff = datetime.now() - timedelta(days=days)
|
||||
result = await db.execute(
|
||||
sql_delete(BackgroundTask).where(
|
||||
BackgroundTask.user_id == user_id,
|
||||
BackgroundTask.status.in_(["completed", "failed", "cancelled"]),
|
||||
BackgroundTask.completed_at < cutoff
|
||||
)
|
||||
)
|
||||
if result.rowcount > 0:
|
||||
await db.commit()
|
||||
logger.info(f"🧹 清理用户 {user_id[:8]} 的 {result.rowcount} 条旧任务记录")
|
||||
|
||||
async def spawn_background_task(
|
||||
self,
|
||||
task_id: str,
|
||||
user_id: str,
|
||||
task_func: Callable[..., Awaitable],
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
将任务加入该用户的队列排队执行(同一用户FIFO,不同用户可并发)
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
user_id: 用户ID
|
||||
task_func: 异步任务函数
|
||||
*args, **kwargs: 传递给task_func的参数
|
||||
"""
|
||||
# 确保该用户的队列和工作协程已启动
|
||||
queue = self._ensure_user_queue(user_id)
|
||||
await self._start_user_worker(user_id)
|
||||
|
||||
# 将任务放入该用户的队列
|
||||
await queue.put({
|
||||
"task_id": task_id,
|
||||
"task_func": task_func,
|
||||
"args": {"user_id": user_id, "extra_args": args},
|
||||
"kwargs": kwargs,
|
||||
})
|
||||
queue_size = queue.qsize()
|
||||
logger.info(f"📥 任务已加入用户 {user_id[:8]} 的队列: {task_id[:8]} (当前队列长度: {queue_size})")
|
||||
|
||||
# 更新任务状态,显示排队位置
|
||||
try:
|
||||
engine = await get_engine(user_id)
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
async with AsyncSessionLocal() as session:
|
||||
result = await session.execute(
|
||||
select(BackgroundTask).where(BackgroundTask.id == task_id)
|
||||
)
|
||||
task = result.scalar_one_or_none()
|
||||
if task and task.status == "pending":
|
||||
if queue_size > 0:
|
||||
task.status_message = f"排队中,前方还有 {queue_size} 个任务等待..."
|
||||
else:
|
||||
task.status_message = "即将开始执行..."
|
||||
task.progress_details = {"stage": "queued", "queue_size": queue_size}
|
||||
task.updated_at = datetime.now()
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"更新队列位置信息失败: {e}")
|
||||
|
||||
def get_queue_size(self, user_id: str = None) -> int:
|
||||
"""获取队列中等待的任务数量"""
|
||||
if user_id:
|
||||
queue = self._user_queues.get(user_id)
|
||||
return queue.qsize() if queue else 0
|
||||
# 所有用户队列总数
|
||||
return sum(q.qsize() for q in self._user_queues.values())
|
||||
|
||||
def get_all_queue_info(self) -> Dict[str, int]:
|
||||
"""获取所有用户的队列信息"""
|
||||
return {
|
||||
uid: q.qsize() for uid, q in self._user_queues.items() if q.qsize() > 0
|
||||
}
|
||||
|
||||
|
||||
# 全局单例
|
||||
background_task_service = BackgroundTaskService()
|
||||
@@ -1284,7 +1284,7 @@ class ForeshadowService:
|
||||
planted_foreshadows = await self.get_planted_foreshadows_for_analysis(db, project_id)
|
||||
|
||||
# 每章最多创建的新伏笔数量
|
||||
MAX_NEW_FORESHADOWS_PER_CHAPTER = 2
|
||||
MAX_NEW_FORESHADOWS_PER_CHAPTER = 5
|
||||
new_foreshadow_count = 0
|
||||
|
||||
for fs_data in analysis_foreshadows:
|
||||
|
||||
@@ -26,15 +26,101 @@ _QUOTE_MAP = {
|
||||
}
|
||||
|
||||
|
||||
def _is_content_quote(text: str, pos: int) -> bool:
|
||||
"""
|
||||
判断字符串值内的 '"' 是否为内容引号(需转义)而非 JSON 结束引号。
|
||||
|
||||
合法 JSON 中,字符串结束引号之后的非空白字符必须是:
|
||||
',' (值分隔) / '}' (关闭对象) / ']' (关闭数组)
|
||||
|
||||
如果 '"' 后面不符合这些模式,则是 AI 写入的内容引号,需要转义。
|
||||
"""
|
||||
j = pos + 1
|
||||
|
||||
# 跳过空格和制表符
|
||||
while j < len(text) and text[j] in ' \t':
|
||||
j += 1
|
||||
|
||||
if j >= len(text):
|
||||
return False # 文本末尾,视为结束引号
|
||||
|
||||
ch = text[j]
|
||||
|
||||
# } 或 ] → 结束引号
|
||||
if ch in ('}', ']'):
|
||||
return False
|
||||
|
||||
# 换行 → 检查下一行开头判断
|
||||
if ch == '\n' or ch == '\r':
|
||||
k = j + (2 if (ch == '\r' and j + 1 < len(text) and text[j + 1] == '\n') else 1)
|
||||
while k < len(text) and text[k] in ' \t':
|
||||
k += 1
|
||||
if k >= len(text):
|
||||
return False
|
||||
# 下一行以 " (JSON key) 或 } 或 ] 开头 → 结束引号
|
||||
if text[k] == '"' or text[k] in ('}', ']'):
|
||||
return False
|
||||
return True
|
||||
|
||||
# , → 需要检查逗号后面是什么
|
||||
if ch == ',':
|
||||
k = j + 1
|
||||
while k < len(text) and text[k] in ' \t':
|
||||
k += 1
|
||||
|
||||
if k >= len(text):
|
||||
return False
|
||||
|
||||
# 逗号后跟换行 → 检查下一行
|
||||
if text[k] in ('\n', '\r'):
|
||||
k2 = k + (2 if (text[k] == '\r' and k + 1 < len(text) and text[k + 1] == '\n') else 1)
|
||||
while k2 < len(text) and text[k2] in ' \t\n\r':
|
||||
k2 += 1
|
||||
if k2 >= len(text):
|
||||
return False
|
||||
if text[k2] == '"' or text[k2] in ('}', ']'):
|
||||
return False
|
||||
return True
|
||||
|
||||
after_comma = text[k]
|
||||
|
||||
# 结构性逗号后应为 JSON 值的开头
|
||||
if after_comma == '"':
|
||||
return False # 字符串值或 key
|
||||
if after_comma.isdigit() or after_comma == '-':
|
||||
return False # 数字
|
||||
if after_comma in ('{', '['):
|
||||
return False # 对象/数组
|
||||
if text[k:k+4] in ('true', 'null'):
|
||||
return False
|
||||
if text[k:k+5] == 'false':
|
||||
return False
|
||||
|
||||
# 逗号后不是 JSON 值开头 → 内容逗号,引号是内容引号
|
||||
return True
|
||||
|
||||
# : → 通常在字符串结束后不可能出现,保守处理为结束引号
|
||||
if ch == ':':
|
||||
return False
|
||||
|
||||
# 其他字符(中文、字母等)→ 内容引号
|
||||
return True
|
||||
|
||||
|
||||
def _fix_json_string_values(text: str) -> str:
|
||||
"""
|
||||
修复JSON字符串值中的常见问题:
|
||||
1. 裸换行符/制表符 → 转义
|
||||
2. 字符串值内的中文引号 → 转义为ASCII引号(避免破坏JSON结构)
|
||||
3. 结构位置的中文引号 → 直接替换为ASCII引号
|
||||
上下文感知的 JSON 修复,区分字符串内外分别处理。
|
||||
|
||||
AI生成的JSON常在字符串值中插入未转义的换行符和中文引号。
|
||||
此函数遍历文本,区分字符串内外,分别处理。
|
||||
字符串值内:
|
||||
1. 裸换行符/制表符 → 转义
|
||||
2. 中文引号(""等) → 转义为 \\"
|
||||
3. 未转义的 ASCII 双引号 → 智能检测:内容引号转义,结束引号保留
|
||||
4. 中文逗号/冒号 → 保留原样(是内容字符)
|
||||
|
||||
结构位置(字符串外):
|
||||
1. 中文引号 → ASCII 引号
|
||||
2. 中文逗号 → ASCII 逗号
|
||||
3. 中文冒号 → ASCII 冒号
|
||||
"""
|
||||
if not text or '"' not in text:
|
||||
return text
|
||||
@@ -47,111 +133,234 @@ def _fix_json_string_values(text: str) -> str:
|
||||
while i < len(text):
|
||||
c = text[i]
|
||||
|
||||
if c == '"' and not in_string:
|
||||
# 进入字符串
|
||||
in_string = True
|
||||
# === 非字符串内(结构位置)===
|
||||
if not in_string:
|
||||
# 结构位置的中文标点 → ASCII
|
||||
if c == '\uff0c': # ,→ ,
|
||||
result.append(',')
|
||||
fixed_count += 1
|
||||
i += 1
|
||||
continue
|
||||
if c == '\uff1a': # :→ :
|
||||
result.append(':')
|
||||
fixed_count += 1
|
||||
i += 1
|
||||
continue
|
||||
if c in _QUOTE_MAP:
|
||||
result.append(_QUOTE_MAP[c])
|
||||
fixed_count += 1
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# ASCII 双引号 → 进入字符串
|
||||
if c == '"':
|
||||
in_string = True
|
||||
result.append(c)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
result.append(c)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if in_string:
|
||||
if c == '\\':
|
||||
# 转义字符,检查下一个字符是否合法
|
||||
if i + 1 < len(text):
|
||||
next_c = text[i + 1]
|
||||
# JSON 合法转义:\" \\ \/ \b \f \n \r \t \uXXXX
|
||||
if next_c in ('"', '\\', '/', 'b', 'f', 'n', 'r', 't'):
|
||||
# 合法转义,直接保留
|
||||
result.append(c)
|
||||
result.append(next_c)
|
||||
i += 2
|
||||
# === 字符串值内 ===
|
||||
|
||||
# 转义字符处理
|
||||
if c == '\\':
|
||||
if i + 1 < len(text):
|
||||
next_c = text[i + 1]
|
||||
if next_c in ('"', '\\', '/', 'b', 'f', 'n', 'r', 't'):
|
||||
result.append(c)
|
||||
result.append(next_c)
|
||||
i += 2
|
||||
continue
|
||||
elif next_c == 'u':
|
||||
if i + 5 < len(text) and all(text[i+2+k] in '0123456789abcdefABCDEF' for k in range(4)):
|
||||
result.append(text[i:i+6])
|
||||
i += 6
|
||||
continue
|
||||
elif next_c == 'u':
|
||||
# Unicode 转义 \uXXXX,检查是否有4个十六进制字符
|
||||
if i + 5 < len(text) and all(text[i+2+k] in '0123456789abcdefABCDEF' for k in range(4)):
|
||||
result.append(text[i:i+6])
|
||||
i += 6
|
||||
continue
|
||||
else:
|
||||
# 不完整的unicode转义,去掉反斜杠
|
||||
result.append(next_c)
|
||||
fixed_count += 1
|
||||
i += 2
|
||||
continue
|
||||
else:
|
||||
# 非法转义字符(如 \c \p \d 等),去掉反斜杠只保留字符
|
||||
result.append(next_c)
|
||||
fixed_count += 1
|
||||
i += 2
|
||||
continue
|
||||
else:
|
||||
# 末尾孤立的反斜杠,去掉
|
||||
result.append(next_c)
|
||||
fixed_count += 1
|
||||
i += 1
|
||||
i += 2
|
||||
continue
|
||||
|
||||
if c == '"':
|
||||
# 字符串结束
|
||||
else:
|
||||
fixed_count += 1
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# ASCII 双引号 → 智能判断是结束引号还是内容引号
|
||||
if c == '"':
|
||||
if _is_content_quote(text, i):
|
||||
# 内容引号,需要转义
|
||||
result.append('\\')
|
||||
result.append('"')
|
||||
fixed_count += 1
|
||||
i += 1
|
||||
continue
|
||||
else:
|
||||
# 结束引号
|
||||
in_string = False
|
||||
result.append(c)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if c == '\n':
|
||||
# 裸换行符 → 替换为转义换行
|
||||
result.append('\\')
|
||||
result.append('n')
|
||||
fixed_count += 1
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if c == '\r':
|
||||
# 裸回车符 → 忽略或替换
|
||||
if i + 1 < len(text) and text[i + 1] == '\n':
|
||||
result.append('\\')
|
||||
result.append('n')
|
||||
fixed_count += 1
|
||||
i += 2
|
||||
else:
|
||||
result.append('\\')
|
||||
result.append('n')
|
||||
fixed_count += 1
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if c == '\t':
|
||||
# 裸制表符 → 替换为转义制表符
|
||||
result.append('\\')
|
||||
result.append('t')
|
||||
fixed_count += 1
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# 字符串值内的中文引号 → 转义为 \"(避免破坏JSON结构)
|
||||
if c in _QUOTE_MAP:
|
||||
result.append('\\')
|
||||
result.append(_QUOTE_MAP[c])
|
||||
fixed_count += 1
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# 非字符串内的字符
|
||||
# 结构位置的中文引号 → 直接替换
|
||||
if not in_string and c in _QUOTE_MAP:
|
||||
result.append(_QUOTE_MAP[c])
|
||||
# 裸换行符 → 转义
|
||||
if c == '\n':
|
||||
result.append('\\')
|
||||
result.append('n')
|
||||
fixed_count += 1
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if c == '\r':
|
||||
if i + 1 < len(text) and text[i + 1] == '\n':
|
||||
result.append('\\')
|
||||
result.append('n')
|
||||
fixed_count += 1
|
||||
i += 2
|
||||
else:
|
||||
result.append('\\')
|
||||
result.append('n')
|
||||
fixed_count += 1
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if c == '\t':
|
||||
result.append('\\')
|
||||
result.append('t')
|
||||
fixed_count += 1
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# 中文引号处理
|
||||
if c in _QUOTE_MAP:
|
||||
mapped = _QUOTE_MAP[c]
|
||||
if mapped == '"':
|
||||
# 中文双引号在字符串内需要转义
|
||||
result.append('\\')
|
||||
result.append('"')
|
||||
else:
|
||||
# 中文单引号在双引号字符串内不需要转义,直接替换
|
||||
result.append(mapped)
|
||||
fixed_count += 1
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# 其他字符(包括中文逗号、中文冒号)→ 保留原样
|
||||
result.append(c)
|
||||
i += 1
|
||||
|
||||
if fixed_count > 0:
|
||||
logger.debug(f"✅ 修复了{fixed_count}个JSON问题(裸控制字符/中文引号)")
|
||||
logger.debug(f"✅ 修复了{fixed_count}个JSON问题(引号/控制字符/中文标点)")
|
||||
|
||||
return ''.join(result)
|
||||
|
||||
|
||||
def _fix_all_invalid_escapes(text: str) -> str:
|
||||
"""
|
||||
兜底修复:扫描整个文本中的无效JSON转义序列。
|
||||
|
||||
当 _fix_json_string_values 因字符串边界追踪错误而遗漏某些无效转义时,
|
||||
此函数作为兜底,不依赖字符串状态追踪,扫描整个文本修复所有无效转义。
|
||||
|
||||
有效JSON转义:\\" \\\\ \\/ \\b \\f \\n \\r \\t \\uXXXX
|
||||
其他 \\X 均为无效转义,修复方式为去掉反斜杠只保留字符。
|
||||
"""
|
||||
if '\\' not in text:
|
||||
return text
|
||||
|
||||
result = []
|
||||
i = 0
|
||||
fixed = 0
|
||||
|
||||
while i < len(text):
|
||||
if text[i] == '\\' and i + 1 < len(text):
|
||||
next_c = text[i + 1]
|
||||
if next_c in ('"', '\\', '/', 'b', 'f', 'n', 'r', 't'):
|
||||
# 有效转义,保留
|
||||
result.append(text[i])
|
||||
result.append(next_c)
|
||||
i += 2
|
||||
continue
|
||||
elif next_c == 'u':
|
||||
# Unicode 转义,检查是否有4个十六进制字符
|
||||
if i + 5 < len(text) and all(
|
||||
text[i + 2 + k] in '0123456789abcdefABCDEF'
|
||||
for k in range(4)
|
||||
):
|
||||
result.append(text[i:i + 6])
|
||||
i += 6
|
||||
continue
|
||||
else:
|
||||
# 不完整的unicode转义,去掉反斜杠
|
||||
result.append(next_c)
|
||||
fixed += 1
|
||||
i += 2
|
||||
continue
|
||||
else:
|
||||
# 无效转义(如 \引 \影 \某种 等),去掉反斜杠只保留字符
|
||||
result.append(next_c)
|
||||
fixed += 1
|
||||
i += 2
|
||||
continue
|
||||
else:
|
||||
result.append(text[i])
|
||||
i += 1
|
||||
|
||||
if fixed > 0:
|
||||
logger.info(f"✅ 兜底修复了{fixed}个无效JSON转义序列")
|
||||
|
||||
return ''.join(result)
|
||||
|
||||
|
||||
def _fix_multiple_objects_as_value(text: str) -> str:
|
||||
"""
|
||||
修复AI生成的JSON中,多个对象作为属性值但未合并的问题。
|
||||
|
||||
示例:
|
||||
"key": {"a": "1"}, {"b": "2"} → "key": {"a": "1", "b": "2"}
|
||||
|
||||
AI有时在输出对象类型的属性值时,输出了多个独立的对象而不是合并为一个。
|
||||
例如 relationship_changes 字段输出多个角色关系变化时可能出现此问题。
|
||||
此函数检测并合并这些对象。
|
||||
"""
|
||||
if '{' not in text or '}' not in text:
|
||||
return text
|
||||
|
||||
# 匹配嵌套层级不超过2的对象: { ... } 其中 ... 不含 { 或仅含单层嵌套
|
||||
nested_obj = r'\{(?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*\}'
|
||||
|
||||
# 模式:属性冒号后跟一个对象,然后逗号和另一个对象(没有属性名)
|
||||
# 即 "key": {obj1}, {obj2} → "key": {obj1, obj2}
|
||||
pattern = r'(":)\s*(' + nested_obj + r')\s*,\s*(' + nested_obj + r')'
|
||||
|
||||
def merge_objects(match):
|
||||
colon = match.group(1)
|
||||
obj1_content = match.group(2)[1:-1] # 去掉外层的 { }
|
||||
obj2_content = match.group(3)[1:-1] # 去掉外层的 { }
|
||||
# 合并为一个对象
|
||||
return f'{colon} {{{obj1_content}, {obj2_content}}}'
|
||||
|
||||
prev = None
|
||||
count = 0
|
||||
max_iterations = 10
|
||||
while prev != text and count < max_iterations:
|
||||
prev = text
|
||||
text = re.sub(pattern, merge_objects, text)
|
||||
count += 1
|
||||
|
||||
if count > 1:
|
||||
logger.info(f"✅ 修复了{count - 1}处多对象属性值合并")
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def clean_json_response(text: str) -> str:
|
||||
"""清洗 AI 返回的 JSON(改进版 - 流式安全)"""
|
||||
try:
|
||||
@@ -162,11 +371,8 @@ def clean_json_response(text: str) -> str:
|
||||
original_length = len(text)
|
||||
logger.debug(f"🔍 开始清洗JSON,原始长度: {original_length}")
|
||||
|
||||
# 替换中文逗号/冒号(AI可能在JSON结构位置使用,全局替换是安全的)
|
||||
text = text.replace('\uff0c', ',') # ,→ ,
|
||||
text = text.replace('\uff1a', ':') # :→ :
|
||||
|
||||
# 修复JSON中的中文引号和裸控制字符(上下文感知,区分字符串内外)
|
||||
# 上下文感知修复:中文引号/逗号/冒号、裸控制字符、未转义的内容引号
|
||||
# (区分字符串内外:结构位置替换为ASCII,字符串内保留或转义)
|
||||
text = _fix_json_string_values(text)
|
||||
|
||||
# 去除 markdown 代码块
|
||||
@@ -286,9 +492,35 @@ def clean_json_response(text: str) -> str:
|
||||
json.loads(result)
|
||||
logger.debug(f"✅ 清洗后JSON验证成功")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"❌ 清洗后JSON仍然无效: {e}")
|
||||
logger.debug(f" 结果预览: {result[:500]}")
|
||||
logger.debug(f" 结果结尾: ...{result[-200:]}")
|
||||
logger.warning(f"⚠️ 清洗后JSON仍然无效: {e},尝试修复结构性问题...")
|
||||
|
||||
# 修复1:合并多对象属性值(AI可能输出 "key": {a:1}, {b:2} )
|
||||
result = _fix_multiple_objects_as_value(result)
|
||||
|
||||
try:
|
||||
json.loads(result)
|
||||
logger.info(f"✅ 修复多对象属性值后JSON验证成功")
|
||||
except json.JSONDecodeError:
|
||||
pass # 继续尝试其他修复
|
||||
else:
|
||||
return result
|
||||
|
||||
# 修复2:兜底修复无效转义序列(不依赖字符串边界追踪)
|
||||
logger.warning(f"⚠️ 继续尝试兜底修复无效转义...")
|
||||
result = _fix_all_invalid_escapes(result)
|
||||
try:
|
||||
json.loads(result)
|
||||
logger.info(f"✅ 兜底修复后JSON验证成功")
|
||||
except json.JSONDecodeError as e2:
|
||||
# 修复3:再次尝试合并多对象属性值(转义修复后可能产生新的合并机会)
|
||||
result = _fix_multiple_objects_as_value(result)
|
||||
try:
|
||||
json.loads(result)
|
||||
logger.info(f"✅ 二次修复后JSON验证成功")
|
||||
except json.JSONDecodeError as e3:
|
||||
logger.error(f"❌ 所有修复后JSON仍然无效: {e3}")
|
||||
logger.debug(f" 结果预览: {result[:500]}")
|
||||
logger.debug(f" 结果结尾: ...{result[-200:]}")
|
||||
|
||||
return result
|
||||
|
||||
@@ -339,6 +571,16 @@ def loads_json(text: str) -> Any:
|
||||
except (json.JSONDecodeError, Exception):
|
||||
pass
|
||||
|
||||
# 兜底修复无效转义序列后重试
|
||||
fixed_text = _fix_all_invalid_escapes(text)
|
||||
if fixed_text != text:
|
||||
try:
|
||||
result = json.loads(fixed_text)
|
||||
logger.info("✅ 兜底修复无效转义后json.loads成功")
|
||||
return result
|
||||
except (json.JSONDecodeError, Exception):
|
||||
pass
|
||||
|
||||
# json5 容错解析
|
||||
if HAS_JSON5:
|
||||
try:
|
||||
@@ -347,6 +589,14 @@ def loads_json(text: str) -> Any:
|
||||
logger.info("✅ json5容错解析成功")
|
||||
return result
|
||||
except Exception as e5:
|
||||
# json5也失败,尝试对修复后的文本使用json5
|
||||
if fixed_text != text:
|
||||
try:
|
||||
result = json5.loads(fixed_text)
|
||||
logger.info("✅ 兜底修复无效转义后json5容错解析成功")
|
||||
return result
|
||||
except Exception:
|
||||
pass
|
||||
logger.error(f"❌ json5容错解析也失败: {e5}")
|
||||
|
||||
# 最终失败,抛出标准异常
|
||||
|
||||
+1
-1
@@ -53,7 +53,7 @@ services:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
image: mumujie/mumuainovel:latest
|
||||
image: mumujie/mumuainovel:dev
|
||||
container_name: mumuainovel
|
||||
depends_on:
|
||||
postgres:
|
||||
|
||||
Generated
+31
@@ -23,6 +23,9 @@ importers:
|
||||
'@types/canvas-confetti':
|
||||
specifier: ^1.9.0
|
||||
version: 1.9.0
|
||||
'@types/dagre':
|
||||
specifier: ^0.7.54
|
||||
version: 0.7.54
|
||||
'@xyflow/react':
|
||||
specifier: ^12.10.1
|
||||
version: 12.10.1(@types/react@18.3.28)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
@@ -35,6 +38,9 @@ importers:
|
||||
canvas-confetti:
|
||||
specifier: ^1.9.4
|
||||
version: 1.9.4
|
||||
dagre:
|
||||
specifier: ^0.8.5
|
||||
version: 0.8.5
|
||||
dayjs:
|
||||
specifier: ^1.11.13
|
||||
version: 1.11.19
|
||||
@@ -734,6 +740,9 @@ packages:
|
||||
'@types/d3-zoom@3.0.8':
|
||||
resolution: {integrity: sha512-iqMC4/YlFCSlO8+2Ii1GGGliCAY4XdeG748w5vQUbevlbDu0zSjH/+jojorQVBK/se0j6DUFNPBGSqD3YWYnDw==}
|
||||
|
||||
'@types/dagre@0.7.54':
|
||||
resolution: {integrity: sha512-QjcRY+adGbYvBFS7cwv5txhVIwX1XXIUswWl+kSQTbI6NjgZydrZkEKX/etzVd7i+bCsCb40Z/xlBY5eoFuvWQ==}
|
||||
|
||||
'@types/estree@1.0.8':
|
||||
resolution: {integrity: sha512-dWHzHa2WqEXI/O1E9OjrocMTKJl2mSrEolh1Iomrv6U+JuNwaHXsXx9bLu5gG7BUWFIN0skIQJQ/L1rIex4X6w==}
|
||||
|
||||
@@ -996,6 +1005,9 @@ packages:
|
||||
resolution: {integrity: sha512-b8AmV3kfQaqWAuacbPuNbL6vahnOJflOhexLzMMNLga62+/nh0JzvJ0aO/5a5MVgUFGS7Hu1P9P03o3fJkDCyw==}
|
||||
engines: {node: '>=12'}
|
||||
|
||||
dagre@0.8.5:
|
||||
resolution: {integrity: sha512-/aTqmnRta7x7MCCpExk7HQL2O4owCT2h8NT//9I1OQ9vt29Pa0BzSAkR5lwFUcQ7491yVi/3CXU9jQ5o0Mn2Sw==}
|
||||
|
||||
dayjs@1.11.19:
|
||||
resolution: {integrity: sha512-t5EcLVS6QPBNqM2z8fakk/NKel+Xzshgt8FFKAn+qwlD1pzZWxh0nVCrvFK7ZDb6XucZeF9z8C7CBWTRIVApAw==}
|
||||
|
||||
@@ -1200,6 +1212,9 @@ packages:
|
||||
resolution: {integrity: sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg==}
|
||||
engines: {node: '>= 0.4'}
|
||||
|
||||
graphlib@2.1.8:
|
||||
resolution: {integrity: sha512-jcLLfkpoVGmH7/InMC/1hIvOPSUh38oJtGhvrOFGzioE1DZ+0YW16RgmOJhHiuWTvGiJQ9Z1Ik43JvkRPRvE+A==}
|
||||
|
||||
has-flag@4.0.0:
|
||||
resolution: {integrity: sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==}
|
||||
engines: {node: '>=8'}
|
||||
@@ -1299,6 +1314,9 @@ packages:
|
||||
lodash.merge@4.6.2:
|
||||
resolution: {integrity: sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==}
|
||||
|
||||
lodash@4.18.1:
|
||||
resolution: {integrity: sha512-dMInicTPVE8d1e5otfwmmjlxkZoUpiVLwyeTdUsi/Caj/gfzzblBcCE5sRHV/AsjuCmxWrte2TNGSYuCeCq+0Q==}
|
||||
|
||||
loose-envify@1.4.0:
|
||||
resolution: {integrity: sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q==}
|
||||
hasBin: true
|
||||
@@ -2498,6 +2516,8 @@ snapshots:
|
||||
'@types/d3-interpolate': 3.0.4
|
||||
'@types/d3-selection': 3.0.11
|
||||
|
||||
'@types/dagre@0.7.54': {}
|
||||
|
||||
'@types/estree@1.0.8': {}
|
||||
|
||||
'@types/json-schema@7.0.15': {}
|
||||
@@ -2861,6 +2881,11 @@ snapshots:
|
||||
d3-selection: 3.0.0
|
||||
d3-transition: 3.0.1(d3-selection@3.0.0)
|
||||
|
||||
dagre@0.8.5:
|
||||
dependencies:
|
||||
graphlib: 2.1.8
|
||||
lodash: 4.18.1
|
||||
|
||||
dayjs@1.11.19: {}
|
||||
|
||||
debug@4.4.3:
|
||||
@@ -3082,6 +3107,10 @@ snapshots:
|
||||
|
||||
gopd@1.2.0: {}
|
||||
|
||||
graphlib@2.1.8:
|
||||
dependencies:
|
||||
lodash: 4.18.1
|
||||
|
||||
has-flag@4.0.0: {}
|
||||
|
||||
has-symbols@1.1.0: {}
|
||||
@@ -3158,6 +3187,8 @@ snapshots:
|
||||
|
||||
lodash.merge@4.6.2: {}
|
||||
|
||||
lodash@4.18.1: {}
|
||||
|
||||
loose-envify@1.4.0:
|
||||
dependencies:
|
||||
js-tokens: 4.0.0
|
||||
|
||||
+394
-30
@@ -1,15 +1,15 @@
|
||||
import { useState, useEffect, useRef, useMemo, useCallback } from 'react';
|
||||
import { List, Button, Modal, Form, Input, Select, message, Empty, Space, Badge, Tag, Card, InputNumber, Alert, Radio, Descriptions, Collapse, Popconfirm, Pagination, theme } from 'antd';
|
||||
import { EditOutlined, FileTextOutlined, ThunderboltOutlined, LockOutlined, DownloadOutlined, SettingOutlined, FundOutlined, SyncOutlined, CheckCircleOutlined, CloseCircleOutlined, RocketOutlined, StopOutlined, InfoCircleOutlined, CaretRightOutlined, DeleteOutlined, BookOutlined, FormOutlined, PlusOutlined, ReadOutlined } from '@ant-design/icons';
|
||||
import { List, Button, Modal, Form, Input, Select, message, Empty, Space, Badge, Tag, Progress, Card, InputNumber, Alert, Radio, Descriptions, Collapse, Popconfirm, Pagination, theme } from 'antd';
|
||||
import { EditOutlined, FileTextOutlined, ThunderboltOutlined, LockOutlined, DownloadOutlined, SettingOutlined, FundOutlined, SyncOutlined, CheckCircleOutlined, CloseCircleOutlined, RocketOutlined, StopOutlined, InfoCircleOutlined, CaretRightOutlined, DeleteOutlined, BookOutlined, FormOutlined, PlusOutlined, ReadOutlined, ClockCircleOutlined, LoadingOutlined } from '@ant-design/icons';
|
||||
import { useStore } from '../store';
|
||||
import { useChapterSync } from '../store/hooks';
|
||||
import { generateChapterBackground, getProjectTasks, cancelTask, deleteTask, type TaskStatus as BgTaskStatus } from '../services/backgroundTaskService';
|
||||
import { projectApi, writingStyleApi, chapterApi } from '../services/api';
|
||||
import type { Chapter, ChapterUpdate, ApiError, WritingStyle, AnalysisTask, ExpansionPlanData } from '../types';
|
||||
import type { TextAreaRef } from 'antd/es/input/TextArea';
|
||||
import ChapterAnalysis from '../components/ChapterAnalysis';
|
||||
import ExpansionPlanEditor from '../components/ExpansionPlanEditor';
|
||||
import { SSELoadingOverlay } from '../components/SSELoadingOverlay';
|
||||
import { SSEProgressModal } from '../components/SSEProgressModal';
|
||||
import ChapterReader from '../components/ChapterReader';
|
||||
import PartialRegenerateToolbar from '../components/PartialRegenerateToolbar';
|
||||
import PartialRegenerateModal from '../components/PartialRegenerateModal';
|
||||
@@ -97,6 +97,112 @@ export default function Chapters() {
|
||||
const [singleChapterProgress, setSingleChapterProgress] = useState(0);
|
||||
const [singleChapterProgressMessage, setSingleChapterProgressMessage] = useState('');
|
||||
|
||||
// 后台生成任务状态
|
||||
const [bgTaskVisible, setBgTaskVisible] = useState(false);
|
||||
const [bgTaskProgress, setBgTaskProgress] = useState(0);
|
||||
const [bgTaskMessage, setBgTaskMessage] = useState('');
|
||||
const [bgTaskRunning, setBgTaskRunning] = useState(false);
|
||||
const bgTaskCancelRef = useRef<(() => void) | null>(null);
|
||||
const [projectBgTasks, setProjectBgTasks] = useState<BgTaskStatus[]>([]);
|
||||
const bgPollTimerRef = useRef<ReturnType<typeof setTimeout> | null>(null);
|
||||
// 后台任务列表 Modal 状态
|
||||
const [taskListVisible, setTaskListVisible] = useState(false);
|
||||
const [taskList, setTaskList] = useState<BgTaskStatus[]>([]);
|
||||
const [taskListLoading, setTaskListLoading] = useState(false);
|
||||
|
||||
// 轮询项目后台任务
|
||||
useEffect(() => {
|
||||
if (!currentProject) return;
|
||||
const pollBgTasks = async () => {
|
||||
try {
|
||||
const resp = await getProjectTasks(currentProject.id, 'chapter_generate', 10);
|
||||
const active = resp.items.filter(t => t.status === 'pending' || t.status === 'running');
|
||||
setProjectBgTasks(active);
|
||||
// 如果有活跃任务,继续轮询
|
||||
if (active.length > 0) {
|
||||
bgPollTimerRef.current = setTimeout(pollBgTasks, 3000);
|
||||
}
|
||||
} catch {}
|
||||
};
|
||||
pollBgTasks();
|
||||
return () => { if (bgPollTimerRef.current) clearTimeout(bgPollTimerRef.current); };
|
||||
}, [currentProject]);
|
||||
|
||||
// 加载并显示后台任务列表
|
||||
const showTaskListModal = async () => {
|
||||
if (!currentProject?.id) return;
|
||||
setTaskListVisible(true);
|
||||
setTaskListLoading(true);
|
||||
try {
|
||||
const result = await getProjectTasks(currentProject.id);
|
||||
setTaskList(result.items || []);
|
||||
} catch (error) {
|
||||
message.error('加载任务列表失败');
|
||||
} finally {
|
||||
setTaskListLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
// 刷新任务列表
|
||||
const refreshTaskList = async () => {
|
||||
if (!currentProject?.id) return;
|
||||
setTaskListLoading(true);
|
||||
try {
|
||||
const result = await getProjectTasks(currentProject.id);
|
||||
setTaskList(result.items || []);
|
||||
const active = (result.items || []).filter(t => t.status === 'pending' || t.status === 'running');
|
||||
setProjectBgTasks(active);
|
||||
} catch (error) {
|
||||
console.error('刷新任务列表失败:', error);
|
||||
} finally {
|
||||
setTaskListLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
// 获取任务状态标签
|
||||
const getTaskStatusTag = (status: BgTaskStatus['status']) => {
|
||||
switch (status) {
|
||||
case 'pending': return <Tag icon={<ClockCircleOutlined />} color="default">等待中</Tag>;
|
||||
case 'running': return <Tag icon={<LoadingOutlined />} color="processing">运行中</Tag>;
|
||||
case 'completed': return <Tag icon={<CheckCircleOutlined />} color="success">已完成</Tag>;
|
||||
case 'failed': return <Tag icon={<CloseCircleOutlined />} color="error">失败</Tag>;
|
||||
case 'cancelled': return <Tag icon={<CloseCircleOutlined />} color="default">已取消</Tag>;
|
||||
default: return <Tag>{status}</Tag>;
|
||||
}
|
||||
};
|
||||
|
||||
// 获取任务类型标签
|
||||
const getTaskTypeLabel = (taskType: string) => {
|
||||
switch (taskType) {
|
||||
case 'chapter_generate': return '章节生成';
|
||||
case 'outline_new': return '大纲生成';
|
||||
case 'outline_continue': return '大纲续写';
|
||||
default: return taskType;
|
||||
}
|
||||
};
|
||||
|
||||
// 处理取消后台任务
|
||||
const handleCancelBgTask = async (taskId: string) => {
|
||||
try {
|
||||
await cancelTask(taskId);
|
||||
message.success('任务已取消');
|
||||
refreshTaskList();
|
||||
} catch (error) {
|
||||
message.error('取消任务失败');
|
||||
}
|
||||
};
|
||||
|
||||
// 处理删除任务记录
|
||||
const handleDeleteBgTask = async (taskId: string) => {
|
||||
try {
|
||||
await deleteTask(taskId);
|
||||
message.success('任务记录已删除');
|
||||
refreshTaskList();
|
||||
} catch (error) {
|
||||
message.error('删除任务记录失败');
|
||||
}
|
||||
};
|
||||
|
||||
// 批量生成相关状态
|
||||
const [batchGenerateVisible, setBatchGenerateVisible] = useState(false);
|
||||
const [batchGenerating, setBatchGenerating] = useState(false);
|
||||
@@ -523,7 +629,7 @@ export default function Chapters() {
|
||||
if (data.has_active_task && data.task) {
|
||||
const task = data.task;
|
||||
|
||||
// 恢复任务状态
|
||||
// 恢复任务状态(只在顶部进度条显示,不弹出Modal)
|
||||
setBatchTaskId(task.batch_id);
|
||||
setBatchProgress({
|
||||
status: task.status,
|
||||
@@ -532,12 +638,12 @@ export default function Chapters() {
|
||||
current_chapter_number: task.current_chapter_number,
|
||||
});
|
||||
setBatchGenerating(true);
|
||||
setBatchGenerateVisible(true);
|
||||
// 不设置 setBatchGenerateVisible(true),避免弹出Modal遮挡页面
|
||||
|
||||
// 启动轮询
|
||||
startBatchPolling(task.batch_id);
|
||||
|
||||
message.info('检测到未完成的批量生成任务,已自动恢复');
|
||||
message.info('检测到未完成的批量生成任务,已在顶部显示进度');
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('检查批量生成任务失败:', error);
|
||||
@@ -971,9 +1077,62 @@ export default function Chapters() {
|
||||
});
|
||||
};
|
||||
|
||||
|
||||
// 后台生成章节(关闭浏览器也不影响)
|
||||
const handleBackgroundGenerate = async () => {
|
||||
if (!editingId) return;
|
||||
if (!selectedStyleId) {
|
||||
message.error("请先选择写作风格");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
setBgTaskVisible(true);
|
||||
setBgTaskRunning(true);
|
||||
setBgTaskProgress(0);
|
||||
setBgTaskMessage("正在创建后台任务...");
|
||||
|
||||
const cancelFn = await generateChapterBackground(
|
||||
editingId,
|
||||
{
|
||||
style_id: selectedStyleId,
|
||||
target_word_count: targetWordCount,
|
||||
model: selectedModel,
|
||||
narrative_perspective: temporaryNarrativePerspective,
|
||||
},
|
||||
(status) => {
|
||||
setBgTaskProgress(status.progress || 0);
|
||||
setBgTaskMessage(status.status_message || "处理中...");
|
||||
},
|
||||
(_) => {
|
||||
setBgTaskProgress(100);
|
||||
setBgTaskMessage("生成完成!");
|
||||
setBgTaskRunning(false);
|
||||
message.success("后台章节生成完成!");
|
||||
refreshChapters();
|
||||
if (currentProject) {
|
||||
projectApi.getProject(currentProject.id).then(setCurrentProject).catch(console.error);
|
||||
}
|
||||
loadAnalysisTasks();
|
||||
},
|
||||
(error) => {
|
||||
setBgTaskRunning(false);
|
||||
setBgTaskMessage("失败: " + error);
|
||||
message.error("后台生成失败: " + error);
|
||||
}
|
||||
);
|
||||
|
||||
bgTaskCancelRef.current = cancelFn;
|
||||
message.info("已提交后台生成任务,可以关闭此页面");
|
||||
} catch (error) {
|
||||
message.error("创建后台任务失败");
|
||||
setBgTaskRunning(false);
|
||||
}
|
||||
};
|
||||
const getStatusColor = (status: string) => {
|
||||
const colors: Record<string, string> = {
|
||||
'draft': 'default',
|
||||
'pending': 'warning',
|
||||
'writing': 'processing',
|
||||
'completed': 'success',
|
||||
};
|
||||
@@ -983,6 +1142,7 @@ export default function Chapters() {
|
||||
const getStatusText = (status: string) => {
|
||||
const texts: Record<string, string> = {
|
||||
'draft': '草稿',
|
||||
'pending': '待处理',
|
||||
'writing': '创作中',
|
||||
'completed': '已完成',
|
||||
};
|
||||
@@ -1387,6 +1547,7 @@ export default function Chapters() {
|
||||
>
|
||||
<Select>
|
||||
<Select.Option value="draft">草稿</Select.Option>
|
||||
<Select.Option value="pending">待处理</Select.Option>
|
||||
<Select.Option value="writing">创作中</Select.Option>
|
||||
<Select.Option value="completed">已完成</Select.Option>
|
||||
</Select>
|
||||
@@ -1931,6 +2092,13 @@ export default function Chapters() {
|
||||
>
|
||||
一键分析{batchAnalyzableChapterCount > 0 ? ` (${batchAnalyzableChapterCount})` : ''}
|
||||
</Button>
|
||||
<Button
|
||||
icon={<ClockCircleOutlined />}
|
||||
onClick={showTaskListModal}
|
||||
>
|
||||
后台任务
|
||||
{projectBgTasks.length > 0 && <Badge count={projectBgTasks.length} size="small" style={{ marginLeft: 4 }} />}
|
||||
</Button>
|
||||
<Button
|
||||
type="primary"
|
||||
icon={<RocketOutlined />}
|
||||
@@ -1955,6 +2123,103 @@ export default function Chapters() {
|
||||
</Space>
|
||||
</div>
|
||||
|
||||
{/* 后台生成任务进度 */}
|
||||
{(projectBgTasks.length > 0 || (batchGenerating && batchProgress)) && (
|
||||
<div style={{
|
||||
marginBottom: 16,
|
||||
padding: '12px 16px',
|
||||
background: token.colorInfoBg,
|
||||
borderRadius: token.borderRadius,
|
||||
border: `1px solid ${token.colorInfoBorder}`
|
||||
}}>
|
||||
<div style={{ display: 'flex', alignItems: 'center', gap: 8, marginBottom: 8 }}>
|
||||
<RocketOutlined style={{ color: token.colorInfo }} spin />
|
||||
<span style={{ fontWeight: 600, color: token.colorInfo }}>
|
||||
后台生成任务
|
||||
</span>
|
||||
<span style={{ fontSize: 12, color: token.colorTextSecondary }}>
|
||||
关闭浏览器也不影响,完成后自动保存
|
||||
</span>
|
||||
</div>
|
||||
{/* 批量生成进度 */}
|
||||
{batchGenerating && batchProgress && (
|
||||
<div style={{
|
||||
display: 'flex', alignItems: 'center', gap: 12,
|
||||
padding: '8px 0',
|
||||
borderBottom: `1px solid ${token.colorBorderSecondary}`
|
||||
}}>
|
||||
<Tag color="processing" style={{ minWidth: 60, textAlign: 'center' }}>
|
||||
批量生成
|
||||
</Tag>
|
||||
<div style={{ flex: 1 }}>
|
||||
<div style={{ fontSize: 12, marginBottom: 4, color: token.colorText }}>
|
||||
{batchProgress.current_chapter_number
|
||||
? `正在生成第 ${batchProgress.current_chapter_number} 章`
|
||||
: '批量生成中...'} ({batchProgress.completed}/{batchProgress.total})
|
||||
</div>
|
||||
<div style={{
|
||||
background: token.colorBgLayout, borderRadius: 4,
|
||||
height: 8, overflow: 'hidden'
|
||||
}}>
|
||||
<div style={{
|
||||
background: token.colorInfo, height: '100%',
|
||||
width: (batchProgress.total > 0 ? Math.round((batchProgress.completed / batchProgress.total) * 100) : 0) + '%',
|
||||
transition: 'width 0.3s'
|
||||
}} />
|
||||
</div>
|
||||
</div>
|
||||
<span style={{ fontSize: 13, fontWeight: 600, color: token.colorInfo, minWidth: 40, textAlign: 'right' }}>
|
||||
{batchProgress.total > 0 ? Math.round((batchProgress.completed / batchProgress.total) * 100) : 0}%
|
||||
</span>
|
||||
<Button size="small" danger onClick={() => {
|
||||
modal.confirm({
|
||||
title: '确认取消',
|
||||
content: '确定要取消批量生成吗?已生成的章节将保留。',
|
||||
okText: '确定取消',
|
||||
cancelText: '继续生成',
|
||||
okButtonProps: { danger: true },
|
||||
centered: true,
|
||||
onOk: handleCancelBatchGenerate,
|
||||
});
|
||||
}}>
|
||||
取消
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
{/* 单章节后台生成进度 */}
|
||||
{projectBgTasks.map(task => (
|
||||
<div key={task.id} style={{
|
||||
display: 'flex', alignItems: 'center', gap: 12,
|
||||
padding: '6px 0',
|
||||
borderBottom: `1px solid ${token.colorBorderSecondary}`
|
||||
}}>
|
||||
<Tag color={task.status === 'running' ? 'processing' : 'default'}
|
||||
style={{ minWidth: 60, textAlign: 'center' }}>
|
||||
{task.status === 'running' ? '生成中' : '排队中'}
|
||||
</Tag>
|
||||
<div style={{ flex: 1 }}>
|
||||
<div style={{
|
||||
background: token.colorBgLayout, borderRadius: 4,
|
||||
height: 6, overflow: 'hidden'
|
||||
}}>
|
||||
<div style={{
|
||||
background: token.colorInfo, height: '100%',
|
||||
width: (task.progress || 0) + '%',
|
||||
transition: 'width 0.3s'
|
||||
}} />
|
||||
</div>
|
||||
</div>
|
||||
<span style={{ fontSize: 12, color: token.colorTextSecondary, minWidth: 40, textAlign: 'right' }}>
|
||||
{task.progress || 0}%
|
||||
</span>
|
||||
<span style={{ fontSize: 12, color: token.colorTextSecondary }}>
|
||||
{task.status_message || ''}
|
||||
</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div style={{ flex: 1, overflowY: 'auto', minHeight: 0 }}>
|
||||
{chapters.length === 0 ? (
|
||||
<Empty description="还没有章节,开始创作吧!" />
|
||||
@@ -2435,6 +2700,7 @@ export default function Chapters() {
|
||||
<Form.Item label="状态" name="status">
|
||||
<Select placeholder="选择状态">
|
||||
<Select.Option value="draft">草稿</Select.Option>
|
||||
<Select.Option value="pending">待处理</Select.Option>
|
||||
<Select.Option value="writing">创作中</Select.Option>
|
||||
<Select.Option value="completed">已完成</Select.Option>
|
||||
</Select>
|
||||
@@ -2497,23 +2763,56 @@ export default function Chapters() {
|
||||
const disabledReason = currentChapter ? getGenerateDisabledReason(currentChapter) : '';
|
||||
|
||||
return (
|
||||
<>
|
||||
<Button
|
||||
type="primary"
|
||||
icon={canGenerate ? <ThunderboltOutlined /> : <LockOutlined />}
|
||||
onClick={() => currentChapter && showGenerateModal(currentChapter)}
|
||||
loading={isContinuing}
|
||||
disabled={!canGenerate}
|
||||
disabled={!canGenerate || bgTaskRunning}
|
||||
danger={!canGenerate}
|
||||
style={{ fontWeight: 'bold' }}
|
||||
title={!canGenerate ? disabledReason : '根据大纲和前置章节内容创作'}
|
||||
title={!canGenerate ? disabledReason : '根据大纲和前置章节内容创作(流式)'}
|
||||
>
|
||||
{isMobile ? 'AI' : 'AI创作'}
|
||||
</Button>
|
||||
<Button
|
||||
icon={<RocketOutlined />}
|
||||
onClick={handleBackgroundGenerate}
|
||||
disabled={!canGenerate || bgTaskRunning || isContinuing}
|
||||
loading={bgTaskRunning}
|
||||
style={{ fontWeight: 'bold' }}
|
||||
title={!canGenerate ? disabledReason : '后台生成:关闭浏览器也不影响,完成后自动保存'}
|
||||
>
|
||||
{isMobile ? '后台' : '后台生成'}
|
||||
</Button>
|
||||
</>
|
||||
);
|
||||
})()}
|
||||
</Space.Compact>
|
||||
</Form.Item>
|
||||
|
||||
{/* 后台生成进度 */}
|
||||
{bgTaskVisible && (
|
||||
<Alert
|
||||
message={bgTaskRunning ? '后台生成进行中...' : '后台生成完成'}
|
||||
description={
|
||||
<div>
|
||||
<div style={{ marginBottom: 8 }}>{bgTaskMessage}</div>
|
||||
<div style={{ background: '#f0f0f0', borderRadius: 4, height: 8, overflow: 'hidden' }}>
|
||||
<div style={{ background: '#1890ff', height: '100%', width: bgTaskProgress + '%', transition: 'width 0.3s' }} />
|
||||
</div>
|
||||
<div style={{ fontSize: 12, color: '#888', marginTop: 4 }}>{bgTaskProgress}%</div>
|
||||
</div>
|
||||
}
|
||||
type={bgTaskRunning ? 'info' : (bgTaskProgress >= 100 ? 'success' : 'error')}
|
||||
showIcon
|
||||
style={{ marginBottom: 12 }}
|
||||
closable={!bgTaskRunning}
|
||||
onClose={() => setBgTaskVisible(false)}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* 第一行:写作风格 + 叙事角度 */}
|
||||
<div style={{
|
||||
display: isMobile ? 'block' : 'flex',
|
||||
@@ -2934,29 +3233,94 @@ export default function Chapters() {
|
||||
message={singleChapterProgressMessage}
|
||||
/>
|
||||
|
||||
{/* 批量生成进度显示 - 使用统一的进度组件 */}
|
||||
<SSEProgressModal
|
||||
visible={batchGenerating}
|
||||
progress={batchProgress ? Math.round((batchProgress.completed / batchProgress.total) * 100) : 0}
|
||||
message={
|
||||
batchProgress?.current_chapter_number
|
||||
? `正在生成第 ${batchProgress.current_chapter_number} 章... (${batchProgress.completed}/${batchProgress.total})`
|
||||
: `批量生成进行中... (${batchProgress?.completed || 0}/${batchProgress?.total || 0})`
|
||||
{/* 后台任务列表 Modal */}
|
||||
<Modal
|
||||
title={
|
||||
<Space>
|
||||
<ClockCircleOutlined />
|
||||
<span>后台任务</span>
|
||||
{taskList.filter(t => t.status === 'running' || t.status === 'pending').length > 0 && (
|
||||
<Badge count={taskList.filter(t => t.status === 'running' || t.status === 'pending').length} />
|
||||
)}
|
||||
</Space>
|
||||
}
|
||||
title="批量生成章节"
|
||||
onCancel={() => {
|
||||
modal.confirm({
|
||||
title: '确认取消',
|
||||
content: '确定要取消批量生成吗?已生成的章节将保留。',
|
||||
okText: '确定取消',
|
||||
cancelText: '继续生成',
|
||||
okButtonProps: { danger: true },
|
||||
centered: true,
|
||||
onOk: handleCancelBatchGenerate,
|
||||
});
|
||||
}}
|
||||
cancelButtonText="取消任务"
|
||||
/>
|
||||
open={taskListVisible}
|
||||
onCancel={() => setTaskListVisible(false)}
|
||||
width={isMobile ? '95%' : 700}
|
||||
centered
|
||||
footer={
|
||||
<Space>
|
||||
<Button icon={<SyncOutlined />} onClick={refreshTaskList} loading={taskListLoading}>
|
||||
刷新
|
||||
</Button>
|
||||
<Button onClick={() => setTaskListVisible(false)}>
|
||||
关闭
|
||||
</Button>
|
||||
</Space>
|
||||
}
|
||||
>
|
||||
{taskListLoading && taskList.length === 0 ? (
|
||||
<div style={{ textAlign: 'center', padding: 40 }}>
|
||||
<LoadingOutlined style={{ fontSize: 24 }} />
|
||||
<div style={{ marginTop: 12, color: token.colorTextSecondary }}>加载中...</div>
|
||||
</div>
|
||||
) : taskList.length === 0 ? (
|
||||
<Empty description="暂无后台任务" />
|
||||
) : (
|
||||
<List
|
||||
dataSource={taskList}
|
||||
renderItem={(task) => (
|
||||
<List.Item
|
||||
key={task.id}
|
||||
actions={[
|
||||
...(task.status === 'running' || task.status === 'pending'
|
||||
? [<Button key="cancel" size="small" danger onClick={() => handleCancelBgTask(task.id)}>取消</Button>]
|
||||
: []
|
||||
),
|
||||
...(task.status === 'completed' || task.status === 'failed' || task.status === 'cancelled'
|
||||
? [<Button key="delete" size="small" type="link" danger onClick={() => handleDeleteBgTask(task.id)}>删除</Button>]
|
||||
: []
|
||||
),
|
||||
].filter(Boolean)}
|
||||
>
|
||||
<List.Item.Meta
|
||||
title={
|
||||
<Space>
|
||||
{getTaskStatusTag(task.status)}
|
||||
<span>{getTaskTypeLabel(task.task_type)}</span>
|
||||
{task.status === 'running' || task.status === 'pending' ? (
|
||||
<Progress percent={task.progress} size="small" style={{ width: 120 }} />
|
||||
) : null}
|
||||
</Space>
|
||||
}
|
||||
description={
|
||||
<div>
|
||||
<div style={{ fontSize: 12, color: token.colorTextSecondary }}>
|
||||
{task.status_message || '无状态信息'}
|
||||
</div>
|
||||
<div style={{ fontSize: 11, color: token.colorTextTertiary, marginTop: 4 }}>
|
||||
创建: {task.created_at ? new Date(task.created_at).toLocaleString() : '-'}
|
||||
{task.completed_at && ' | 完成: ' + new Date(task.completed_at).toLocaleString()}
|
||||
</div>
|
||||
{task.error_message && (
|
||||
<div style={{ fontSize: 12, color: token.colorError, marginTop: 4 }}>
|
||||
{'❌ ' + task.error_message}
|
||||
</div>
|
||||
)}
|
||||
{task.task_result && task.status === 'completed' && (
|
||||
<div style={{ fontSize: 12, color: token.colorSuccess, marginTop: 4 }}>
|
||||
{'✅ ' + ((task.task_result as Record<string, unknown>).message as string || '任务完成')}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
}
|
||||
/>
|
||||
</List.Item>
|
||||
)}
|
||||
/>
|
||||
)}
|
||||
</Modal>
|
||||
|
||||
|
||||
{/* 章节阅读器 */}
|
||||
{readingChapter && (
|
||||
|
||||
+209
-24
@@ -1,10 +1,11 @@
|
||||
import { useState, useEffect, useMemo } from 'react';
|
||||
import { Button, List, Modal, Form, Input, message, Empty, Space, Popconfirm, Card, Select, Radio, Tag, InputNumber, Tabs, Pagination, theme } from 'antd';
|
||||
import { EditOutlined, DeleteOutlined, ThunderboltOutlined, BranchesOutlined, AppstoreAddOutlined, CheckCircleOutlined, ExclamationCircleOutlined, PlusOutlined, FileTextOutlined } from '@ant-design/icons';
|
||||
import { useState, useEffect, useMemo, useRef } from 'react';
|
||||
import { Button, List, Modal, Form, Input, message, Empty, Space, Popconfirm, Card, Select, Radio, Tag, InputNumber, Tabs, Pagination, theme, Progress, Badge, Tooltip } from 'antd';
|
||||
import { EditOutlined, DeleteOutlined, ThunderboltOutlined, BranchesOutlined, AppstoreAddOutlined, CheckCircleOutlined, ExclamationCircleOutlined, PlusOutlined, FileTextOutlined, ClockCircleOutlined, ReloadOutlined, CloseCircleOutlined, LoadingOutlined } from '@ant-design/icons';
|
||||
import { useStore } from '../store';
|
||||
import { useOutlineSync } from '../store/hooks';
|
||||
import { SSEPostClient } from '../utils/sseClient';
|
||||
import { SSEProgressModal } from '../components/SSEProgressModal';
|
||||
import { generateOutlineBackground, getProjectTasks, cancelTask, deleteTask, type TaskStatus } from '../services/backgroundTaskService';
|
||||
import { outlineApi, chapterApi, projectApi, characterApi } from '../services/api';
|
||||
import type { OutlineExpansionResponse, BatchOutlineExpansionResponse, ChapterPlanItem, ApiError, Character } from '../types';
|
||||
|
||||
@@ -154,6 +155,14 @@ export default function Outline() {
|
||||
const [sseMessage, setSSEMessage] = useState('');
|
||||
const [sseModalVisible, setSSEModalVisible] = useState(false);
|
||||
|
||||
// 后台任务取消函数引用
|
||||
const cancelGenerateRef = useRef<(() => void) | null>(null);
|
||||
|
||||
// 后台任务列表状态
|
||||
const [taskListVisible, setTaskListVisible] = useState(false);
|
||||
const [taskList, setTaskList] = useState<TaskStatus[]>([]);
|
||||
const [taskListLoading, setTaskListLoading] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
const handleResize = () => {
|
||||
setIsMobile(window.innerWidth <= 768);
|
||||
@@ -573,33 +582,31 @@ export default function Outline() {
|
||||
console.log('6. 最终请求数据:', JSON.stringify(requestData, null, 2));
|
||||
console.log('=========================');
|
||||
|
||||
// 使用SSE客户端
|
||||
const apiUrl = `/api/outlines/generate-stream`;
|
||||
const client = new SSEPostClient(apiUrl, requestData, {
|
||||
onProgress: (msg: string, progress: number) => {
|
||||
setSSEMessage(msg);
|
||||
setSSEProgress(progress);
|
||||
// 使用后台任务生成(不怕断连,关闭浏览器也继续运行)
|
||||
setSSEMessage('正在创建后台任务...');
|
||||
|
||||
const cancelFn = await generateOutlineBackground(
|
||||
requestData,
|
||||
(status) => {
|
||||
setSSEProgress(status.progress);
|
||||
setSSEMessage(status.status_message || '处理中...');
|
||||
},
|
||||
onResult: (data: unknown) => {
|
||||
console.log('生成完成,结果:', data);
|
||||
(result) => {
|
||||
message.success(result.task_result?.message as string || '大纲生成完成!');
|
||||
setSSEModalVisible(false);
|
||||
setIsGenerating(false);
|
||||
cancelGenerateRef.current = null;
|
||||
refreshOutlines();
|
||||
},
|
||||
onError: (error: string) => {
|
||||
// 现在只处理真正的错误
|
||||
(error) => {
|
||||
message.error(`生成失败: ${error}`);
|
||||
setSSEModalVisible(false);
|
||||
setIsGenerating(false);
|
||||
},
|
||||
onComplete: () => {
|
||||
message.success('大纲生成完成!');
|
||||
setSSEModalVisible(false);
|
||||
setIsGenerating(false);
|
||||
// 刷新大纲列表
|
||||
refreshOutlines();
|
||||
cancelGenerateRef.current = null;
|
||||
}
|
||||
});
|
||||
);
|
||||
|
||||
// 开始连接
|
||||
client.connect();
|
||||
cancelGenerateRef.current = cancelFn;
|
||||
|
||||
} catch (error) {
|
||||
console.error('AI生成失败:', error);
|
||||
@@ -1895,8 +1902,168 @@ export default function Outline() {
|
||||
};
|
||||
|
||||
|
||||
// 加载并显示后台任务列表
|
||||
const showTaskListModal = async () => {
|
||||
if (!currentProject?.id) return;
|
||||
setTaskListVisible(true);
|
||||
setTaskListLoading(true);
|
||||
try {
|
||||
const result = await getProjectTasks(currentProject.id);
|
||||
setTaskList(result.items || []);
|
||||
} catch (error) {
|
||||
message.error('加载任务列表失败');
|
||||
} finally {
|
||||
setTaskListLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
// 刷新任务列表
|
||||
const refreshTaskList = async () => {
|
||||
if (!currentProject?.id) return;
|
||||
setTaskListLoading(true);
|
||||
try {
|
||||
const result = await getProjectTasks(currentProject.id);
|
||||
setTaskList(result.items || []);
|
||||
} catch (error) {
|
||||
console.error('刷新任务列表失败:', error);
|
||||
} finally {
|
||||
setTaskListLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
// 获取任务状态标签
|
||||
const getTaskStatusTag = (status: TaskStatus['status']) => {
|
||||
switch (status) {
|
||||
case 'pending': return <Tag icon={<ClockCircleOutlined />} color="default">等待中</Tag>;
|
||||
case 'running': return <Tag icon={<LoadingOutlined />} color="processing">运行中</Tag>;
|
||||
case 'completed': return <Tag icon={<CheckCircleOutlined />} color="success">已完成</Tag>;
|
||||
case 'failed': return <Tag icon={<CloseCircleOutlined />} color="error">失败</Tag>;
|
||||
case 'cancelled': return <Tag icon={<CloseCircleOutlined />} color="default">已取消</Tag>;
|
||||
default: return <Tag>{status}</Tag>;
|
||||
}
|
||||
};
|
||||
|
||||
// 获取任务类型标签
|
||||
const getTaskTypeLabel = (taskType: string) => {
|
||||
switch (taskType) {
|
||||
case 'outline_new': return '大纲生成';
|
||||
case 'outline_continue': return '大纲续写';
|
||||
default: return taskType;
|
||||
}
|
||||
};
|
||||
|
||||
// 处理取消后台任务
|
||||
const handleCancelTask = async (taskId: string) => {
|
||||
try {
|
||||
await cancelTask(taskId);
|
||||
message.success('任务已取消');
|
||||
refreshTaskList();
|
||||
} catch (error) {
|
||||
message.error('取消任务失败');
|
||||
}
|
||||
};
|
||||
|
||||
// 处理删除任务记录
|
||||
const handleDeleteTask = async (taskId: string) => {
|
||||
try {
|
||||
await deleteTask(taskId);
|
||||
message.success('任务记录已删除');
|
||||
refreshTaskList();
|
||||
} catch (error) {
|
||||
message.error('删除任务记录失败');
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
{/* 后台任务列表 Modal */}
|
||||
<Modal
|
||||
title={
|
||||
<Space>
|
||||
<ClockCircleOutlined />
|
||||
<span>后台任务</span>
|
||||
{taskList.filter(t => t.status === 'running' || t.status === 'pending').length > 0 && (
|
||||
<Badge count={taskList.filter(t => t.status === 'running' || t.status === 'pending').length} />
|
||||
)}
|
||||
</Space>
|
||||
}
|
||||
open={taskListVisible}
|
||||
onCancel={() => setTaskListVisible(false)}
|
||||
width={isMobile ? '95%' : 700}
|
||||
centered
|
||||
footer={
|
||||
<Space>
|
||||
<Button icon={<ReloadOutlined />} onClick={refreshTaskList} loading={taskListLoading}>
|
||||
刷新
|
||||
</Button>
|
||||
<Button onClick={() => setTaskListVisible(false)}>
|
||||
关闭
|
||||
</Button>
|
||||
</Space>
|
||||
}
|
||||
>
|
||||
{taskListLoading && taskList.length === 0 ? (
|
||||
<div style={{ textAlign: 'center', padding: 40 }}>
|
||||
<LoadingOutlined style={{ fontSize: 24 }} />
|
||||
<div style={{ marginTop: 12, color: token.colorTextSecondary }}>加载中...</div>
|
||||
</div>
|
||||
) : taskList.length === 0 ? (
|
||||
<Empty description="暂无后台任务" />
|
||||
) : (
|
||||
<List
|
||||
dataSource={taskList}
|
||||
renderItem={(task) => (
|
||||
<List.Item
|
||||
key={task.id}
|
||||
actions={[
|
||||
...(task.status === 'running' || task.status === 'pending'
|
||||
? [<Button key="cancel" size="small" danger onClick={() => handleCancelTask(task.id)}>取消</Button>]
|
||||
: []
|
||||
),
|
||||
...(task.status === 'completed' || task.status === 'failed' || task.status === 'cancelled'
|
||||
? [<Button key="delete" size="small" type="link" danger onClick={() => handleDeleteTask(task.id)}>删除</Button>]
|
||||
: []
|
||||
),
|
||||
].filter(Boolean)}
|
||||
>
|
||||
<List.Item.Meta
|
||||
title={
|
||||
<Space>
|
||||
{getTaskStatusTag(task.status)}
|
||||
<span>{getTaskTypeLabel(task.task_type)}</span>
|
||||
{task.status === 'running' || task.status === 'pending' ? (
|
||||
<Progress percent={task.progress} size="small" style={{ width: 120 }} />
|
||||
) : null}
|
||||
</Space>
|
||||
}
|
||||
description={
|
||||
<div>
|
||||
<div style={{ fontSize: 12, color: token.colorTextSecondary }}>
|
||||
{task.status_message || '无状态信息'}
|
||||
</div>
|
||||
<div style={{ fontSize: 11, color: token.colorTextTertiary, marginTop: 4 }}>
|
||||
创建: {task.created_at ? new Date(task.created_at).toLocaleString() : '-'}
|
||||
{task.completed_at && ` | 完成: ${new Date(task.completed_at).toLocaleString()}`}
|
||||
</div>
|
||||
{task.error_message && (
|
||||
<div style={{ fontSize: 12, color: token.colorError, marginTop: 4 }}>
|
||||
❌ {task.error_message}
|
||||
</div>
|
||||
)}
|
||||
{task.task_result && task.status === 'completed' && (
|
||||
<div style={{ fontSize: 12, color: token.colorSuccess, marginTop: 4 }}>
|
||||
✅ {(task.task_result as Record<string, unknown>).message as string || '任务完成'}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
}
|
||||
/>
|
||||
</List.Item>
|
||||
)}
|
||||
/>
|
||||
)}
|
||||
</Modal>
|
||||
|
||||
{/* 批量展开预览 Modal */}
|
||||
<Modal
|
||||
title={
|
||||
@@ -1923,7 +2090,16 @@ export default function Outline() {
|
||||
visible={sseModalVisible}
|
||||
progress={sseProgress}
|
||||
message={sseMessage}
|
||||
title="AI生成中..."
|
||||
title="AI生成中(后台运行,可关闭页面)..."
|
||||
onCancel={() => {
|
||||
if (cancelGenerateRef.current) {
|
||||
cancelGenerateRef.current();
|
||||
cancelGenerateRef.current = null;
|
||||
}
|
||||
setSSEModalVisible(false);
|
||||
setIsGenerating(false);
|
||||
message.info('已取消生成任务');
|
||||
}}
|
||||
/>
|
||||
|
||||
<div style={{ display: 'flex', flexDirection: 'column', height: '100%' }}>
|
||||
@@ -1977,6 +2153,15 @@ export default function Outline() {
|
||||
>
|
||||
{isMobile ? 'AI生成/续写' : 'AI生成/续写大纲'}
|
||||
</Button>
|
||||
<Tooltip title="查看后台任务进度">
|
||||
<Button
|
||||
icon={<ClockCircleOutlined />}
|
||||
onClick={showTaskListModal}
|
||||
block={isMobile}
|
||||
>
|
||||
{isMobile ? '任务' : '后台任务'}
|
||||
</Button>
|
||||
</Tooltip>
|
||||
{outlines.length > 0 && currentProject?.outline_mode === 'one-to-many' && (
|
||||
<Button
|
||||
icon={<AppstoreAddOutlined />}
|
||||
|
||||
@@ -0,0 +1,227 @@
|
||||
/**
|
||||
* 后台任务服务 - 轮询任务进度,替代SSE
|
||||
*/
|
||||
|
||||
const API_BASE = '/api/tasks';
|
||||
|
||||
export interface TaskStatus {
|
||||
id: string;
|
||||
task_type: string;
|
||||
project_id: string;
|
||||
status: 'pending' | 'running' | 'completed' | 'failed' | 'cancelled';
|
||||
progress: number; // 0-100
|
||||
status_message: string | null;
|
||||
progress_details: {
|
||||
stage: string;
|
||||
message: string;
|
||||
current_chars?: number;
|
||||
retry_count?: number;
|
||||
} | null;
|
||||
error_message: string | null;
|
||||
task_result: Record<string, unknown> | null;
|
||||
retry_count: number;
|
||||
cancel_requested: boolean;
|
||||
created_at: string | null;
|
||||
started_at: string | null;
|
||||
completed_at: string | null;
|
||||
updated_at: string | null;
|
||||
}
|
||||
|
||||
export interface TaskListResponse {
|
||||
items: TaskStatus[];
|
||||
}
|
||||
|
||||
/**
|
||||
* 查询任务状态
|
||||
*/
|
||||
export async function getTaskStatus(taskId: string): Promise<TaskStatus> {
|
||||
const response = await fetch(`${API_BASE}/${taskId}`);
|
||||
if (!response.ok) {
|
||||
throw new Error(`查询任务状态失败: ${response.statusText}`);
|
||||
}
|
||||
return response.json();
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取项目的任务列表
|
||||
*/
|
||||
export async function getProjectTasks(
|
||||
projectId: string,
|
||||
taskType?: string,
|
||||
limit: number = 20
|
||||
): Promise<TaskListResponse> {
|
||||
const params = new URLSearchParams({ project_id: projectId, limit: String(limit) });
|
||||
if (taskType) params.set('task_type', taskType);
|
||||
const response = await fetch(`${API_BASE}?${params}`);
|
||||
if (!response.ok) {
|
||||
throw new Error(`获取任务列表失败: ${response.statusText}`);
|
||||
}
|
||||
return response.json();
|
||||
}
|
||||
|
||||
/**
|
||||
* 取消任务
|
||||
*/
|
||||
export async function cancelTask(taskId: string): Promise<void> {
|
||||
const response = await fetch(`${API_BASE}/${taskId}/cancel`, { method: 'POST' });
|
||||
if (!response.ok) {
|
||||
throw new Error(`取消任务失败: ${response.statusText}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 删除任务记录
|
||||
*/
|
||||
export async function deleteTask(taskId: string): Promise<void> {
|
||||
const response = await fetch(`${API_BASE}/${taskId}`, { method: 'DELETE' });
|
||||
if (!response.ok) {
|
||||
throw new Error(`删除任务失败: ${response.statusText}`);
|
||||
}
|
||||
}
|
||||
|
||||
export type TaskProgressCallback = (status: TaskStatus) => void;
|
||||
export type TaskCompleteCallback = (result: TaskStatus) => void;
|
||||
export type TaskErrorCallback = (error: string, status: TaskStatus) => void;
|
||||
|
||||
/**
|
||||
* 轮询任务直到完成
|
||||
*
|
||||
* @param taskId 任务ID
|
||||
* @param onProgress 进度回调
|
||||
* @param onComplete 完成回调
|
||||
* @param onError 错误回调
|
||||
* @param intervalMs 轮询间隔(毫秒),默认2000
|
||||
* @returns 取消轮询的函数
|
||||
*/
|
||||
export function pollTaskUntilComplete(
|
||||
taskId: string,
|
||||
onProgress: TaskProgressCallback,
|
||||
onComplete: TaskCompleteCallback,
|
||||
onError: TaskErrorCallback,
|
||||
intervalMs: number = 2000
|
||||
): () => void {
|
||||
let cancelled = false;
|
||||
let timerId: ReturnType<typeof setTimeout>;
|
||||
|
||||
const poll = async () => {
|
||||
if (cancelled) return;
|
||||
|
||||
try {
|
||||
const status = await getTaskStatus(taskId);
|
||||
|
||||
if (cancelled) return;
|
||||
|
||||
onProgress(status);
|
||||
|
||||
if (status.status === 'completed') {
|
||||
onComplete(status);
|
||||
return;
|
||||
}
|
||||
|
||||
if (status.status === 'failed') {
|
||||
onError(status.error_message || '任务失败', status);
|
||||
return;
|
||||
}
|
||||
|
||||
if (status.status === 'cancelled') {
|
||||
onError('任务已取消', status);
|
||||
return;
|
||||
}
|
||||
|
||||
// 继续轮询(运行中时加快轮询频率)
|
||||
const nextInterval = status.status === 'running' ? intervalMs : intervalMs * 2;
|
||||
timerId = setTimeout(poll, nextInterval);
|
||||
} catch (err) {
|
||||
if (!cancelled) {
|
||||
onError(err instanceof Error ? err.message : '查询任务状态失败', {} as TaskStatus);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// 立即开始第一次轮询
|
||||
timerId = setTimeout(poll, 0);
|
||||
|
||||
// 返回取消函数
|
||||
return () => {
|
||||
cancelled = true;
|
||||
clearTimeout(timerId);
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* 请求后台生成大纲并轮询进度
|
||||
*
|
||||
* @param data 请求参数(同原来的generate-stream)
|
||||
* @param onProgress 进度回调
|
||||
* @param onComplete 完成回调
|
||||
* @param onError 错误回调
|
||||
* @returns 取消函数(同时取消轮询和后台任务)
|
||||
*/
|
||||
export async function generateOutlineBackground(
|
||||
data: unknown,
|
||||
onProgress: TaskProgressCallback,
|
||||
onComplete: TaskCompleteCallback,
|
||||
onError: TaskErrorCallback
|
||||
): Promise<() => void> {
|
||||
// 1. 创建后台任务
|
||||
const response = await fetch('/api/outlines/generate', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(data),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const err = await response.json().catch(() => ({ detail: response.statusText }));
|
||||
onError(err.detail || '创建任务失败', {} as TaskStatus);
|
||||
return () => {};
|
||||
}
|
||||
|
||||
const { task_id } = await response.json();
|
||||
|
||||
// 2. 开始轮询
|
||||
let cancelPolling = pollTaskUntilComplete(task_id, onProgress, onComplete, onError);
|
||||
|
||||
// 3. 返回统一的取消函数(取消轮询 + 取消后台任务)
|
||||
return () => {
|
||||
cancelPolling();
|
||||
cancelTask(task_id).catch(() => {});
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* 请求后台生成章节内容并轮询进度
|
||||
* 关闭浏览器不影响生成,生成完成后内容自动保存到数据库
|
||||
*/
|
||||
export async function generateChapterBackground(
|
||||
chapterId: string,
|
||||
options: {
|
||||
style_id?: number | null;
|
||||
target_word_count?: number;
|
||||
model?: string | null;
|
||||
narrative_perspective?: string | null;
|
||||
enable_mcp?: boolean;
|
||||
},
|
||||
onProgress: TaskProgressCallback,
|
||||
onComplete: TaskCompleteCallback,
|
||||
onError: TaskErrorCallback
|
||||
): Promise<() => void> {
|
||||
const response = await fetch(`/api/chapters/${chapterId}/generate-background`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(options),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const err = await response.json().catch(() => ({ detail: response.statusText }));
|
||||
onError(err.detail || '创建章节生成任务失败', {} as TaskStatus);
|
||||
return () => {};
|
||||
}
|
||||
|
||||
const { task_id } = await response.json();
|
||||
const cancelPolling = pollTaskUntilComplete(task_id, onProgress, onComplete, onError);
|
||||
|
||||
return () => {
|
||||
cancelPolling();
|
||||
cancelTask(task_id).catch(() => {});
|
||||
};
|
||||
}
|
||||
Reference in New Issue
Block a user