From 2bd8b61e918a861de842040a5c4cdd5413991a8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=AA=E6=9D=A5?= Date: Wed, 29 Apr 2026 08:31:07 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=90=8E=E5=8F=B0=E4=BB=BB=E5=8A=A1?= =?UTF-8?q?=E7=B3=BB=E7=BB=9F=20+=20JSON=E5=AE=B9=E9=94=99=E8=A7=A3?= =?UTF-8?q?=E6=9E=90=20+=20SSE=E5=BF=83=E8=B7=B3=E4=BF=9D=E6=B4=BB=20+=20?= =?UTF-8?q?=E5=A4=9A=E9=A1=B9Bug=E4=BF=AE=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新功能: - 大纲/章节生成改为服务端后台任务,支持断线续传 - 后台任务队列排队执行,按用户排队(同用户串行不同用户并发) - 章节管理页面添加后台任务列表弹窗和进度面板 - 章节状态添加 pending(待处理)选项 - 集成json5容错解析器 + 上下文感知JSON修复 - SSE流式生成添加心跳保活,防止连接超时 - SSEPostClient添加credentials:include修复network error - 每章最大伏笔数从2调整为5 - 添加大纲读区伏笔的功能 Bug修复: - 修复AI生成JSON中未转义引号/中文标点/多对象属性值未合并 - 修复JSON非法转义字符清洗和中文引号处理 - 修复MCP插件TimeoutError/连接失败上下文清理 - MCP插件后台注册添加重试机制 - 续写模式添加缺失的mcp_references参数 - 修复Alembic迁移链分叉 - 使用torch CPU版本加速Docker构建 --- .gitignore | 3 +- .../20260427_1200_abc12345_添加后台任务表.py | 46 ++ ...260427_1200_def45678ghi9_添加后台任务表.py | 44 ++ backend/app/api/chapters.py | 445 +++++++++++++++ backend/app/api/outlines.py | 517 ++++++++++++++++++ backend/app/api/settings.py | 38 ++ backend/app/api/tasks.py | 122 +++++ backend/app/database.py | 3 +- backend/app/main.py | 19 +- backend/app/models/__init__.py | 6 +- backend/app/models/background_task.py | 46 ++ backend/app/schemas/chapter.py | 1 + .../app/services/background_task_service.py | 387 +++++++++++++ backend/app/services/foreshadow_service.py | 2 +- backend/app/services/json_helper.py | 428 ++++++++++++--- docker-compose.yml | 2 +- frontend/pnpm-lock.yaml | 31 ++ frontend/src/pages/Chapters.tsx | 424 +++++++++++++- frontend/src/pages/Outline.tsx | 233 +++++++- .../src/services/backgroundTaskService.ts | 227 ++++++++ 20 files changed, 2873 insertions(+), 151 deletions(-) create mode 100644 backend/alembic/postgres/versions/20260427_1200_abc12345_添加后台任务表.py create mode 100644 backend/alembic/sqlite/versions/20260427_1200_def45678ghi9_添加后台任务表.py create mode 100644 backend/app/api/tasks.py create mode 100644 backend/app/models/background_task.py create mode 100644 backend/app/services/background_task_service.py create mode 100644 frontend/src/services/backgroundTaskService.ts diff --git a/.gitignore b/.gitignore index b6302eb..f5830fe 100644 --- a/.gitignore +++ b/.gitignore @@ -124,4 +124,5 @@ test_api.py backend/embedding/ # 提示词工坊实例标识(每个部署实例必须唯一) -backend/.instance_id \ No newline at end of file +backend/.instance_id +test.json diff --git a/backend/alembic/postgres/versions/20260427_1200_abc12345_添加后台任务表.py b/backend/alembic/postgres/versions/20260427_1200_abc12345_添加后台任务表.py new file mode 100644 index 0000000..69b10af --- /dev/null +++ b/backend/alembic/postgres/versions/20260427_1200_abc12345_添加后台任务表.py @@ -0,0 +1,46 @@ +"""添加后台任务表 + +Revision ID: abc12345 +Revises: +Create Date: 2026-04-27 12:00:00.000000 +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import JSON, JSONB + +# revision identifiers +revision = 'abc12345' +down_revision = '9a1b2c3d4e5f' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + 'background_tasks', + sa.Column('id', sa.String(36), primary_key=True), + sa.Column('user_id', sa.String(100), nullable=False, index=True, comment='用户ID'), + sa.Column('project_id', sa.String(36), nullable=False, index=True, comment='项目ID'), + sa.Column('task_type', sa.String(50), nullable=False, comment='任务类型'), + sa.Column('status', sa.String(20), default='pending', comment='任务状态'), + sa.Column('progress', sa.Integer, default=0, comment='进度百分比'), + sa.Column('status_message', sa.String(500), comment='当前状态消息'), + sa.Column('task_input', JSON, comment='任务输入参数'), + sa.Column('task_result', JSON, comment='任务结果'), + sa.Column('error_message', sa.Text, comment='错误信息'), + sa.Column('progress_details', JSON, comment='进度详情'), + sa.Column('cancel_requested', sa.Boolean, default=False, comment='是否请求取消'), + sa.Column('retry_count', sa.Integer, default=0, comment='已重试次数'), + sa.Column('max_retries', sa.Integer, default=3, comment='最大重试次数'), + sa.Column('created_at', sa.DateTime, server_default=sa.func.now(), comment='创建时间'), + sa.Column('started_at', sa.DateTime, comment='开始时间'), + sa.Column('completed_at', sa.DateTime, comment='完成时间'), + sa.Column('updated_at', sa.DateTime, server_default=sa.func.now(), onupdate=sa.func.now(), comment='更新时间'), + ) + # 添加复合索引:按用户+项目+状态查询 + op.create_index('ix_background_tasks_user_project', 'background_tasks', ['user_id', 'project_id', 'status']) + + +def downgrade() -> None: + op.drop_index('ix_background_tasks_user_project', table_name='background_tasks') + op.drop_table('background_tasks') \ No newline at end of file diff --git a/backend/alembic/sqlite/versions/20260427_1200_def45678ghi9_添加后台任务表.py b/backend/alembic/sqlite/versions/20260427_1200_def45678ghi9_添加后台任务表.py new file mode 100644 index 0000000..372bac1 --- /dev/null +++ b/backend/alembic/sqlite/versions/20260427_1200_def45678ghi9_添加后台任务表.py @@ -0,0 +1,44 @@ +"""添加后台任务表 + +Revision ID: def45678ghi9 +Revises: ab12cd34ef56 +Create Date: 2026-04-27 12:00:00.000000 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'def45678ghi9' +down_revision = 'ab12cd34ef56' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + 'background_tasks', + sa.Column('id', sa.String(), primary_key=True), + sa.Column('task_type', sa.String(), nullable=False, index=True), + sa.Column('project_id', sa.String(), nullable=False, index=True), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('status', sa.String(), default='pending', nullable=False, index=True), + sa.Column('progress', sa.Float(), default=0.0), + sa.Column('status_message', sa.String(), nullable=True), + sa.Column('progress_details', sa.Text(), nullable=True), + sa.Column('error_message', sa.Text(), nullable=True), + sa.Column('task_params', sa.Text(), nullable=True), + sa.Column('task_result', sa.Text(), nullable=True), + sa.Column('cancel_requested', sa.Boolean(), default=False), + sa.Column('retry_count', sa.Integer(), default=0), + sa.Column('max_retries', sa.Integer(), default=3), + sa.Column('created_at', sa.DateTime(), server_default=sa.func.now()), + sa.Column('started_at', sa.DateTime(), nullable=True), + sa.Column('completed_at', sa.DateTime(), nullable=True), + sa.Column('updated_at', sa.DateTime(), server_default=sa.func.now(), onupdate=sa.func.now()), + ) + + +def downgrade() -> None: + op.drop_table('background_tasks') \ No newline at end of file diff --git a/backend/app/api/chapters.py b/backend/app/api/chapters.py index f1753a3..f32c8cc 100644 --- a/backend/app/api/chapters.py +++ b/backend/app/api/chapters.py @@ -27,6 +27,7 @@ from app.models.analysis_task import AnalysisTask from app.models.memory import PlotAnalysis, StoryMemory from app.models.batch_generation_task import BatchGenerationTask from app.models.regeneration_task import RegenerationTask +from app.models.background_task import BackgroundTask from app.schemas.chapter import ( ChapterCreate, ChapterUpdate, @@ -1815,6 +1816,450 @@ async def generate_chapter_content_stream( return create_sse_response(event_generator()) +@router.post("/{chapter_id}/generate-background", summary="AI创作章节内容(后台任务)") +async def generate_chapter_content_background( + chapter_id: str, + request: Request, + generate_request: ChapterGenerateRequest = ChapterGenerateRequest(), + db: AsyncSession = Depends(get_db) +): + """ + 创建后台任务来生成章节内容。 + 任务创建后立即返回task_id,前端通过 GET /api/tasks/{task_id} 轮询进度。 + 关闭浏览器不影响生成,生成完成后内容自动保存到数据库。 + """ + user_id = getattr(request.state, 'user_id', None) + if not user_id: + raise HTTPException(status_code=401, detail="未登录") + + # 验证章节存在 + result = await db.execute( + select(Chapter).where(Chapter.id == chapter_id) + ) + chapter = result.scalar_one_or_none() + if not chapter: + raise HTTPException(status_code=404, detail="章节不存在") + + # 验证项目权限 + project = await verify_project_access(chapter.project_id, user_id, db) + + # 检查前置条件 + can_generate, error_msg, _ = await check_prerequisites(db, chapter) + if not can_generate: + raise HTTPException(status_code=400, detail=error_msg) + + # 创建后台任务 + from app.services.background_task_service import background_task_service, TaskProgressTracker + task = await background_task_service.create_task( + user_id=user_id, + project_id=chapter.project_id, + task_type="chapter_generate", + task_input={ + "chapter_id": chapter_id, + "style_id": generate_request.style_id, + "target_word_count": generate_request.target_word_count or 3000, + "enable_mcp": generate_request.enable_mcp, + "model": generate_request.model, + "narrative_perspective": generate_request.narrative_perspective, + }, + db=db + ) + + # 后台执行的函数 + async def _run_chapter_generation(task_id: str, bg_user_id: str): + from app.database import get_engine + from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession as BgAsyncSession + + engine = await get_engine(bg_user_id) + AsyncSessionLocal = async_sessionmaker(engine, class_=BgAsyncSession, expire_on_commit=False) + + async with AsyncSessionLocal() as bg_db: + tracker = TaskProgressTracker(task_id, bg_user_id, "章节") + try: + await tracker.start() + + # 获取AI服务 + from app.api.settings import get_user_ai_service_from_db + bg_ai_service = await get_user_ai_service_from_db(bg_user_id, bg_db) + + await _run_chapter_generation_bg( + task_input={ + "chapter_id": chapter_id, + "style_id": generate_request.style_id, + "target_word_count": generate_request.target_word_count or 3000, + "enable_mcp": generate_request.enable_mcp, + "model": generate_request.model, + "narrative_perspective": generate_request.narrative_perspective, + }, + db=bg_db, + ai_service=bg_ai_service, + tracker=tracker, + user_id=bg_user_id, + task_id=task_id, + ) + + except Exception as e: + logger.error(f"❌ 后台章节生成失败: {e}", exc_info=True) + await tracker.error(str(e)) + + await background_task_service.spawn_background_task( + task.id, user_id, _run_chapter_generation + ) + + return { + "task_id": task.id, + "task_type": "chapter_generate", + "status": "pending", + "message": "任务已创建,请通过 GET /api/tasks/{task_id} 查询进度" + } + + +async def _run_chapter_generation_bg( + task_input: dict, + db: AsyncSession, + ai_service: AIService, + tracker, + user_id: str, + task_id: str, +): + """后台执行章节生成(不使用SSE,直接生成并保存)""" + from app.services.chapter_context_service import ( + OneToManyContextBuilder, + OneToOneContextBuilder + ) + + chapter_id = task_input["chapter_id"] + style_id = task_input.get("style_id") + target_word_count = task_input.get("target_word_count", 3000) + custom_model = task_input.get("model") + temp_narrative_perspective = task_input.get("narrative_perspective") + write_lock = await get_db_write_lock(user_id) + + # === 加载阶段 === + await tracker.loading("加载章节信息...", 0.2) + + chapter_result = await db.execute( + select(Chapter).where(Chapter.id == chapter_id) + ) + current_chapter = chapter_result.scalar_one_or_none() + if not current_chapter: + await tracker.error("章节不存在") + return + + await tracker.loading("加载项目信息...", 0.4) + + project_result = await db.execute( + select(Project).where(Project.id == current_chapter.project_id) + ) + project = project_result.scalar_one_or_none() + if not project: + await tracker.error("项目不存在") + return + + outline_mode = project.outline_mode if project else 'one-to-many' + + # 获取大纲 + if current_chapter.outline_id: + outline_result = await db.execute( + select(Outline).where(Outline.id == current_chapter.outline_id) + ) + else: + outline_result = await db.execute( + select(Outline) + .where(Outline.project_id == current_chapter.project_id) + .where(Outline.order_index == current_chapter.chapter_number) + ) + outline = outline_result.scalar_one_or_none() + + # 获取写作风格 + style_content = "" + if style_id: + style_result = await db.execute( + select(WritingStyle).where(WritingStyle.id == style_id) + ) + style = style_result.scalar_one_or_none() + if style and (style.user_id is None or style.user_id == user_id): + style_content = style.prompt_content or "" + + # === 构建上下文 === + if outline_mode == 'one-to-one': + context_builder = OneToOneContextBuilder( + memory_service=memory_service, + foreshadow_service=foreshadow_service + ) + chapter_context = await context_builder.build( + chapter=current_chapter, + project=project, + outline=outline, + user_id=user_id, + db=db, + target_word_count=target_word_count + ) + else: + context_builder = OneToManyContextBuilder( + memory_service=memory_service, + foreshadow_service=foreshadow_service + ) + chapter_context = await context_builder.build( + chapter=current_chapter, + project=project, + outline=outline, + user_id=user_id, + db=db, + style_content=style_content, + target_word_count=target_word_count, + temp_narrative_perspective=temp_narrative_perspective + ) + + await tracker.loading("上下文构建完成", 0.8) + + # 确定叙事人称 + chapter_perspective = ( + temp_narrative_perspective or + project.narrative_perspective or + '第三人称' + ) + + # === 准备提示词 === + if outline_mode == 'one-to-one': + if chapter_context.continuation_point: + template = await PromptService.get_template("CHAPTER_GENERATION_ONE_TO_ONE_NEXT", user_id, db) + base_prompt = PromptService.format_prompt( + template, + project_title=project.title, + chapter_number=current_chapter.chapter_number, + chapter_title=current_chapter.title, + chapter_outline=chapter_context.chapter_outline, + target_word_count=target_word_count, + genre=project.genre or '未设定', + narrative_perspective=chapter_perspective, + previous_chapter_content=chapter_context.continuation_point, + previous_chapter_summary=chapter_context.previous_chapter_summary or '(无上一章摘要)', + characters_info=chapter_context.chapter_characters or '暂无角色信息', + chapter_careers=chapter_context.chapter_careers or '暂无职业信息', + foreshadow_reminders=chapter_context.foreshadow_reminders or '暂无需要关注的伏笔', + relevant_memories=chapter_context.relevant_memories or '暂无相关记忆' + ) + else: + template = await PromptService.get_template("CHAPTER_GENERATION_ONE_TO_ONE", user_id, db) + base_prompt = PromptService.format_prompt( + template, + project_title=project.title, + chapter_number=current_chapter.chapter_number, + chapter_title=current_chapter.title, + chapter_outline=chapter_context.chapter_outline, + target_word_count=target_word_count, + genre=project.genre or '未设定', + narrative_perspective=chapter_perspective, + characters_info=chapter_context.chapter_characters or '暂无角色信息', + chapter_careers=chapter_context.chapter_careers or '暂无职业信息', + foreshadow_reminders=chapter_context.foreshadow_reminders or '暂无需要关注的伏笔', + relevant_memories=chapter_context.relevant_memories or '暂无相关记忆' + ) + else: + if chapter_context.continuation_point: + previous_summary = chapter_context.previous_chapter_summary or "(无上一章摘要,请根据锚点续写)" + template = await PromptService.get_template("CHAPTER_GENERATION_ONE_TO_MANY_NEXT", user_id, db) + base_prompt = PromptService.format_prompt( + template, + project_title=project.title, + chapter_number=current_chapter.chapter_number, + chapter_title=current_chapter.title, + chapter_outline=chapter_context.chapter_outline, + target_word_count=target_word_count, + continuation_point=chapter_context.continuation_point, + genre=project.genre or '未设定', + narrative_perspective=chapter_perspective, + characters_info=chapter_context.chapter_characters or '暂无角色信息', + chapter_careers=chapter_context.chapter_careers or '暂无职业信息', + foreshadow_reminders=chapter_context.foreshadow_reminders or '暂无需要关注的伏笔', + previous_chapter_summary=previous_summary, + recent_chapters_context=chapter_context.recent_chapters_context or '', + relevant_memories=chapter_context.relevant_memories or '' + ) + else: + template = await PromptService.get_template("CHAPTER_GENERATION_ONE_TO_MANY", user_id, db) + base_prompt = PromptService.format_prompt( + template, + project_title=project.title, + chapter_number=current_chapter.chapter_number, + chapter_title=current_chapter.title, + chapter_outline=chapter_context.chapter_outline, + target_word_count=target_word_count, + genre=project.genre or '未设定', + narrative_perspective=chapter_perspective, + characters_info=chapter_context.chapter_characters or '暂无角色信息', + chapter_careers=chapter_context.chapter_careers or '暂无职业信息', + foreshadow_reminders=chapter_context.foreshadow_reminders or '暂无需要关注的伏笔', + relevant_memories=chapter_context.relevant_memories or '暂无相关记忆' + ) + + # 应用写作风格 + if style_content: + prompt = WritingStyleManager.apply_style_to_prompt(base_prompt, style_content) + else: + prompt = base_prompt + + # === 准备阶段 === + await tracker.preparing("准备AI提示词...") + + system_prompt_with_style = None + if style_content: + system_prompt_with_style = f"""【🎨 写作风格要求 - 最高优先级】 + +{style_content} + +⚠️ 请严格遵循上述写作风格要求进行创作,这是最重要的指令! +确保在整个章节创作过程中始终保持风格的一致性。""" + + calculated_max_tokens = int(target_word_count * 3) + calculated_max_tokens = max(2000, min(calculated_max_tokens, 16000)) + + generate_kwargs = { + "prompt": prompt, + "system_prompt": system_prompt_with_style, + "tool_choice": "required", + "max_tokens": calculated_max_tokens + } + if custom_model: + generate_kwargs["model"] = custom_model + + # === 生成阶段 === + full_content = "" + chunk_count = 0 + + await tracker.generating( + current_chars=0, + estimated_total=target_word_count + ) + + async for chunk in ai_service.generate_text_stream(**generate_kwargs): + # 检查是否被取消 + if chunk_count % 10 == 0 and await tracker.check_cancelled(): + logger.info(f"🚫 后台章节生成被取消: {chapter_id}") + return + + full_content += chunk + chunk_count += 1 + + # 每10个chunk更新一次进度 + if chunk_count % 10 == 0: + await tracker.generating( + current_chars=len(full_content), + estimated_total=target_word_count, + message=f'正在创作中... 已生成 {len(full_content)} 字' + ) + + await asyncio.sleep(0) + + # === 保存阶段 === + await tracker.saving("正在保存章节...", 0.3) + + async with write_lock: + # 重新获取章节(确保最新状态) + chapter_result = await db.execute( + select(Chapter).where(Chapter.id == chapter_id) + ) + current_chapter = chapter_result.scalar_one_or_none() + if not current_chapter: + await tracker.error("保存时章节不存在") + return + + old_word_count = current_chapter.word_count or 0 + current_chapter.content = full_content + new_word_count = len(full_content) + current_chapter.word_count = new_word_count + current_chapter.status = "completed" + + # 更新项目字数 + project_result = await db.execute( + select(Project).where(Project.id == current_chapter.project_id) + ) + project_obj = project_result.scalar_one_or_none() + if project_obj: + project_obj.current_words = (project_obj.current_words or 0) - old_word_count + new_word_count + + # 记录生成历史 + history = GenerationHistory( + project_id=current_chapter.project_id, + chapter_id=current_chapter.id, + prompt=f"创作章节: 第{current_chapter.chapter_number}章 {current_chapter.title}", + generated_content=full_content[:500] if len(full_content) > 500 else full_content, + model="default" + ) + db.add(history) + + await db.commit() + + logger.info(f"✅ 后台创作章节 {chapter_id} 完成,共 {new_word_count} 字") + + # 🔮 自动标记伏笔 + try: + plant_result = await foreshadow_service.auto_plant_pending_foreshadows( + db=db, + project_id=current_chapter.project_id, + chapter_id=chapter_id, + chapter_number=current_chapter.chapter_number, + chapter_content=full_content + ) + if plant_result.get('planted_count', 0) > 0: + logger.info(f"🔮 自动标记伏笔已埋入: {plant_result['planted_count']}个") + except Exception as plant_error: + logger.warning(f"⚠️ 自动标记伏笔埋入失败: {str(plant_error)}") + + # 创建分析任务 + analysis_task = AnalysisTask( + chapter_id=chapter_id, + user_id=user_id, + project_id=current_chapter.project_id, + status='pending', + progress=0 + ) + db.add(analysis_task) + await db.commit() + await db.refresh(analysis_task) + + logger.info(f"📋 后台生成:已创建分析任务: {analysis_task.id}") + + await asyncio.sleep(0.05) + + # 启动后台分析 + asyncio.create_task( + analyze_chapter_background( + chapter_id=chapter_id, + user_id=user_id, + project_id=current_chapter.project_id, + task_id=analysis_task.id, + ai_service=ai_service + ) + ) + + # === 完成 === + await tracker.complete(f"创作完成!共 {new_word_count} 字") + + # 更新任务结果 + from app.services.background_task_service import background_task_service + from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession as BgAsyncSession + from app.database import get_engine as bg_get_engine + try: + engine = await bg_get_engine(user_id) + AsyncSessionLocal = async_sessionmaker(engine, class_=BgAsyncSession, expire_on_commit=False) + async with AsyncSessionLocal() as result_db: + from sqlalchemy import update as sql_update + await result_db.execute( + sql_update(BackgroundTask) + .where(BackgroundTask.id == task_id) + .values(task_result={ + "chapter_id": chapter_id, + "word_count": new_word_count, + "analysis_task_id": analysis_task.id + }) + ) + await result_db.commit() + except Exception as e: + logger.warning(f"⚠️ 更新任务结果失败: {e}") + + def _build_analysis_task_status_payload( chapter_id: str, task: Optional[AnalysisTask], diff --git a/backend/app/api/outlines.py b/backend/app/api/outlines.py index aac00e4..1dc4e57 100644 --- a/backend/app/api/outlines.py +++ b/backend/app/api/outlines.py @@ -1717,6 +1717,523 @@ async def continue_outline_generator( yield await tracker.error(f"续写失败: {str(e)}") +@router.post("/generate", summary="AI生成/续写大纲(后台任务)") +async def generate_outline_task( + data: Dict[str, Any], + request: Request, + db: AsyncSession = Depends(get_db), + user_ai_service: AIService = Depends(get_user_ai_service) +): + """ + 使用后台任务生成或续写小说大纲(不怕断连,关闭浏览器也继续运行) + + 返回task_id,前端通过 GET /api/tasks/{task_id} 轮询进度 + + 支持模式: + - auto/new/continue(同 generate-stream) + """ + from app.services.background_task_service import background_task_service, TaskProgressTracker + from app.database import get_engine + from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession as NewAsyncSession + + user_id = getattr(request.state, 'user_id', None) + project = await verify_project_access(data.get("project_id"), user_id, db) + + # 判断模式 + mode = data.get("mode", "auto") + existing_result = await db.execute( + select(Outline) + .where(Outline.project_id == data.get("project_id")) + .order_by(Outline.order_index) + ) + existing_outlines = existing_result.scalars().all() + + if mode == "auto": + mode = "continue" if existing_outlines else "new" + + data["user_id"] = user_id + data["mode"] = mode + + if mode == "continue" and not existing_outlines: + raise HTTPException(status_code=400, detail="续写模式需要已有大纲") + + # 创建后台任务 + task_type = "outline_new" if mode == "new" else "outline_continue" + task = await background_task_service.create_task( + user_id=user_id, + project_id=data.get("project_id"), + task_type=task_type, + task_input=data, + db=db + ) + + # 后台执行的函数 + async def _run_outline_generation(task_id: str, user_id: str): + engine = await get_engine(user_id) + AsyncSessionLocal = async_sessionmaker(engine, class_=NewAsyncSession, expire_on_commit=False) + + async with AsyncSessionLocal() as bg_db: + tracker = TaskProgressTracker(task_id, user_id, "大纲") + try: + await tracker.start() + + # 获取AI服务(需要在后台创建新实例) + from app.api.settings import get_user_ai_service_from_db + bg_ai_service = await get_user_ai_service_from_db(user_id, bg_db) + + if mode == "new": + await _run_new_outline_bg(data, bg_db, bg_ai_service, tracker) + else: + await _run_continue_outline_bg(data, bg_db, bg_ai_service, tracker, user_id) + + except Exception as e: + logger.error(f"❌ 后台大纲生成失败: {e}", exc_info=True) + await tracker.error(str(e)) + + await background_task_service.spawn_background_task( + task.id, user_id, _run_outline_generation + ) + + return { + "task_id": task.id, + "task_type": task_type, + "status": "pending", + "message": "任务已创建,请通过 GET /api/tasks/{task_id} 查询进度" + } + + +async def _run_new_outline_bg( + data: Dict[str, Any], + db: AsyncSession, + user_ai_service: AIService, + tracker +): + """后台执行全新大纲生成""" + from app.database import get_engine + from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession as BgAsyncSession + + project_id = data.get("project_id") + chapter_count = int(data.get("chapter_count", 10)) + user_id_for_mcp = data.get("user_id") + + await tracker.loading("加载项目信息...", 0.3) + result = await db.execute(select(Project).where(Project.id == project_id)) + project = result.scalar_one_or_none() + if not project: + await tracker.error("项目不存在") + return + + await tracker.loading(f"准备生成{chapter_count}章大纲...", 0.6) + characters_result = await db.execute(select(Character).where(Character.project_id == project_id)) + characters = characters_result.scalars().all() + characters_info = _build_characters_info(characters) + + if user_id_for_mcp: + user_ai_service.user_id = user_id_for_mcp + user_ai_service.db_session = db + + await tracker.preparing("准备AI提示词...") + template = await PromptService.get_template("OUTLINE_CREATE", user_id_for_mcp, db) + prompt = PromptService.format_prompt( + template, + title=project.title, + theme=data.get("theme") or project.theme or "未设定", + genre=data.get("genre") or project.genre or "通用", + chapter_count=chapter_count, + narrative_perspective=data.get("narrative_perspective") or "第三人称", + time_period=project.world_time_period or "未设定", + location=project.world_location or "未设定", + atmosphere=project.world_atmosphere or "未设定", + rules=project.world_rules or "未设定", + characters_info=characters_info or "暂无角色信息", + requirements=data.get("requirements") or "", + mcp_references="" + ) + + model_param = data.get("model") + provider_param = data.get("provider") + + estimated_total = chapter_count * 1000 + accumulated_text = "" + chunk_count = 0 + + await tracker.generating(current_chars=0, estimated_total=estimated_total) + + async for chunk in user_ai_service.generate_text_stream( + prompt=prompt, provider=provider_param, model=model_param + ): + chunk_count += 1 + accumulated_text += chunk + if chunk_count % 10 == 0: + if await tracker.check_cancelled(): + await tracker.error("任务已取消") + return + await tracker.generating( + current_chars=len(accumulated_text), + estimated_total=estimated_total + ) + + await tracker.parsing("解析大纲数据...") + ai_content = accumulated_text + + # 解析响应(带重试) + max_retries = 2 + retry_count = 0 + outline_data = None + + while retry_count <= max_retries: + try: + outline_data = _parse_ai_response(ai_content, raise_on_error=True) + break + except JSONParseError: + retry_count += 1 + if retry_count > max_retries: + outline_data = _parse_ai_response(ai_content, raise_on_error=False) + break + await tracker.retry(retry_count, max_retries, "JSON解析失败") + tracker.reset_generating_progress() + accumulated_text = "" + retry_prompt = prompt + "\n\n【重要提醒】请确保返回完整的JSON数组。" + async for chunk in user_ai_service.generate_text_stream( + prompt=retry_prompt, provider=provider_param, model=model_param + ): + accumulated_text += chunk + ai_content = accumulated_text + + # ✅ P0修复:先保存新数据,再删除旧数据 + await tracker.saving("保存新大纲到数据库...", 0.2) + outlines = await _save_outlines(project_id, outline_data, db, start_index=1) + await db.commit() # 先提交新数据! + logger.info(f"✅ 新大纲已保存: {len(outlines)} 章") + + # 新数据安全后,再清理旧数据 + await tracker.saving("清理旧数据...", 0.6) + try: + from sqlalchemy import delete as sql_delete + + # 获取旧大纲(不包括刚保存的) + new_outline_ids = [o.id for o in outlines] + old_outlines_result = await db.execute( + select(Outline).where( + Outline.project_id == project_id, + ~Outline.id.in_(new_outline_ids) + ) + ) + old_outlines = old_outlines_result.scalars().all() + + if old_outlines: + old_outline_ids = [o.id for o in old_outlines] + + # 清理旧章节 + old_chapters_result = await db.execute( + select(Chapter).where( + Chapter.project_id == project_id, + ~Chapter.id.in_([ch.id for ch in await db.execute( + select(Chapter).where(Chapter.outline_id.in_(new_outline_ids) if new_outline_ids else False) + ).scalars().all()] if new_outline_ids else []) + ) + ) + # 简化:删除不属于新大纲的旧章节 + # 先获取新大纲对应的章节(one-to-one模式下通过chapter_number匹配) + new_order_indexes = [o.order_index for o in outlines] + + if project.outline_mode == 'one-to-one': + old_chapters_result = await db.execute( + select(Chapter).where( + Chapter.project_id == project_id, + ~Chapter.chapter_number.in_(new_order_indexes) + ) + ) + else: + old_chapters_result = await db.execute( + select(Chapter).where( + Chapter.project_id == project_id, + Chapter.outline_id.in_(old_outline_ids) + ) + ) + + old_chapters = old_chapters_result.scalars().all() + deleted_word_count = sum(ch.word_count or 0 for ch in old_chapters) + + # 清理伏笔和记忆 + for ch in old_chapters: + try: + await memory_service.delete_chapter_memories( + user_id=user_id_for_mcp, project_id=project_id, chapter_id=ch.id + ) + except Exception: + pass + try: + await foreshadow_service.delete_chapter_foreshadows( + db=db, project_id=project_id, chapter_id=ch.id, only_analysis_source=True + ) + except Exception: + pass + + # 删除旧章节 + if project.outline_mode == 'one-to-one': + await db.execute( + sql_delete(Chapter).where( + Chapter.project_id == project_id, + ~Chapter.chapter_number.in_(new_order_indexes) + ) + ) + else: + await db.execute( + sql_delete(Chapter).where(Chapter.outline_id.in_(old_outline_ids)) + ) + + if deleted_word_count > 0: + project.current_words = max(0, project.current_words - deleted_word_count) + + # 清理伏笔 + try: + await foreshadow_service.clear_project_foreshadows_for_reset(db, project_id) + except Exception: + pass + + # 清理分析 + try: + from app.models.memory import PlotAnalysis + await db.execute(sql_delete(PlotAnalysis).where(PlotAnalysis.project_id == project_id)) + except Exception: + pass + + # 删除旧大纲 + await db.execute( + sql_delete(Outline).where(Outline.id.in_(old_outline_ids)) + ) + + await db.commit() + logger.info(f"✅ 旧数据清理完成: 删除 {len(old_outlines)} 个旧大纲, {len(old_chapters)} 个旧章节") + + except Exception as e: + logger.error(f"❌ 清理旧数据失败(新数据已安全保存): {e}") + # 新数据已保存,旧数据清理失败不影响 + + # 角色校验 + await tracker.saving("🎭 校验角色信息...", 0.7) + try: + await _check_and_create_missing_characters_from_outlines( + outline_data=outline_data, project_id=project_id, db=db, + user_ai_service=user_ai_service, user_id=data.get("user_id"), + enable_mcp=data.get("enable_mcp", True), tracker=tracker + ) + except Exception: + pass + + # 组织校验 + try: + await _check_and_create_missing_organizations_from_outlines( + outline_data=outline_data, project_id=project_id, db=db, + user_ai_service=user_ai_service, user_id=data.get("user_id"), + enable_mcp=data.get("enable_mcp", True), tracker=tracker + ) + except Exception: + pass + + # 保存结果到任务记录 + result_data = { + "message": f"成功生成{len(outlines)}章大纲", + "total_chapters": len(outlines), + "outline_ids": [o.id for o in outlines] + } + + # 更新任务结果 + from app.models.background_task import BackgroundTask + task_result = await db.execute(select(BackgroundTask).where(BackgroundTask.id == tracker.task_id)) + bg_task = task_result.scalar_one_or_none() + if bg_task: + bg_task.task_result = result_data + await db.commit() + + await tracker.complete(f"成功生成{len(outlines)}章大纲") + logger.info(f"✅ 后台大纲生成完成: {len(outlines)} 章") + + +async def _run_continue_outline_bg( + data: Dict[str, Any], + db: AsyncSession, + user_ai_service: AIService, + tracker, + user_id: str +): + """后台执行大纲续写""" + project_id = data.get("project_id") + total_chapters = int(data.get("chapter_count", 5)) + + await tracker.loading("加载项目信息...", 0.2) + result = await db.execute(select(Project).where(Project.id == project_id)) + project = result.scalar_one_or_none() + if not project: + await tracker.error("项目不存在") + return + + existing_result = await db.execute( + select(Outline).where(Outline.project_id == project_id).order_by(Outline.order_index) + ) + existing_outlines = existing_result.scalars().all() + if not existing_outlines: + await tracker.error("续写模式需要已有大纲") + return + + last_chapter_number = existing_outlines[-1].order_index + + characters_result = await db.execute(select(Character).where(Character.project_id == project_id)) + characters = characters_result.scalars().all() + + batch_size = 5 + total_batches = (total_chapters + batch_size - 1) // batch_size + all_new_outlines = [] + current_start_chapter = last_chapter_number + 1 + + stage_instructions = { + "development": "继续展开情节,深化角色关系", + "climax": "进入故事高潮,矛盾激化", + "ending": "解决主要冲突,给出结局" + } + stage_instruction = stage_instructions.get(data.get("plot_stage", "development"), "") + + for batch_num in range(total_batches): + if await tracker.check_cancelled(): + await tracker.error("任务已取消") + return + + remaining = total_chapters - len(all_new_outlines) + current_batch_size = min(batch_size, remaining) + tracker.reset_generating_progress() + + await tracker.generating( + message=f"📝 第{batch_num + 1}/{total_batches}批: 生成第{current_start_chapter}-{current_start_chapter + current_batch_size - 1}章" + ) + + latest_result = await db.execute( + select(Outline).where(Outline.project_id == project_id).order_by(Outline.order_index) + ) + latest_outlines = latest_result.scalars().all() + + context = await _build_outline_continue_context( + project=project, latest_outlines=latest_outlines, characters=characters, + chapter_count=current_batch_size, + plot_stage=data.get("plot_stage", "development"), + story_direction=data.get("story_direction", "自然延续"), + requirements=data.get("requirements", ""), db=db + ) + + user_ai_service.user_id = user_id + user_ai_service.db_session = db + + # 获取伏笔提醒 + foreshadow_reminders_text = "暂无需要关注的伏笔" + try: + foreshadow_context = await foreshadow_service.build_chapter_context( + db=db, project_id=project_id, chapter_number=current_start_chapter, + include_pending=False, include_overdue=True, lookahead=10 + ) + if foreshadow_context and foreshadow_context.get("context_text"): + foreshadow_reminders_text = foreshadow_context["context_text"] + except Exception: + pass + + template = await PromptService.get_template("OUTLINE_CONTINUE", user_id, db) + prompt = PromptService.format_prompt( + template, + title=project.title, theme=project.theme or "未设定", + genre=project.genre or "通用", + narrative_perspective=project.narrative_perspective or "第三人称", + time_period=project.world_time_period or "未设定", + location=project.world_location or "未设定", + atmosphere=project.world_atmosphere or "未设定", + rules=project.world_rules or "未设定", + recent_outlines=context['recent_outlines'], + characters_info=context['characters_info'], + foreshadow_reminders=foreshadow_reminders_text, + chapter_count=current_batch_size, + start_chapter=current_start_chapter, + end_chapter=current_start_chapter + current_batch_size - 1, + current_chapter_count=len(latest_outlines), + plot_stage_instruction=stage_instruction, + story_direction=data.get("story_direction", "自然延续"), + requirements=data.get("requirements", ""), + mcp_references="" + ) + + accumulated_text = "" + chunk_count = 0 + estimated_chars = current_batch_size * 1000 + + async for chunk in user_ai_service.generate_text_stream( + prompt=prompt, provider=data.get("provider"), model=data.get("model") + ): + chunk_count += 1 + accumulated_text += chunk + if chunk_count % 10 == 0: + await tracker.generating( + current_chars=len(accumulated_text), estimated_total=estimated_chars, + message=f"📝 第{batch_num + 1}/{total_batches}批生成中..." + ) + + await tracker.parsing(f"解析第{batch_num + 1}批数据...") + + # 解析 + max_retries = 2 + retry_count = 0 + outline_data = None + while retry_count <= max_retries: + try: + outline_data = _parse_ai_response(accumulated_text, raise_on_error=True) + break + except JSONParseError: + retry_count += 1 + if retry_count > max_retries: + outline_data = _parse_ai_response(accumulated_text, raise_on_error=False) + break + await tracker.retry(retry_count, max_retries, "JSON解析失败") + tracker.reset_generating_progress() + accumulated_text = "" + retry_prompt = prompt + "\n\n【重要提醒】请确保返回完整的JSON数组。" + async for chunk in user_ai_service.generate_text_stream( + prompt=retry_prompt, provider=data.get("provider"), model=data.get("model") + ): + accumulated_text += chunk + + # 保存当前批次 + await tracker.saving(f"保存第{batch_num + 1}批大纲...", 0.5) + batch_outlines = await _save_outlines( + project_id, outline_data, db, start_index=current_start_chapter + ) + await db.commit() + all_new_outlines.extend(batch_outlines) + current_start_chapter += current_batch_size + + # 角色校验 + try: + await _check_and_create_missing_characters_from_outlines( + outline_data=outline_data, project_id=project_id, db=db, + user_ai_service=user_ai_service, user_id=user_id, + enable_mcp=data.get("enable_mcp", True), tracker=tracker + ) + await db.commit() + except Exception: + pass + + # 保存结果 + result_data = { + "message": f"成功续写{len(all_new_outlines)}章大纲", + "total_chapters": len(all_new_outlines), + "outline_ids": [o.id for o in all_new_outlines] + } + from app.models.background_task import BackgroundTask + task_result = await db.execute(select(BackgroundTask).where(BackgroundTask.id == tracker.task_id)) + bg_task = task_result.scalar_one_or_none() + if bg_task: + bg_task.task_result = result_data + await db.commit() + + await tracker.complete(f"成功续写{len(all_new_outlines)}章大纲") + logger.info(f"✅ 后台大纲续写完成: {len(all_new_outlines)} 章") + + @router.post("/generate-stream", summary="AI生成/续写大纲(SSE流式)") async def generate_outline_stream( data: Dict[str, Any], diff --git a/backend/app/api/settings.py b/backend/app/api/settings.py index 13ec0bb..4ad7f4a 100644 --- a/backend/app/api/settings.py +++ b/backend/app/api/settings.py @@ -160,6 +160,44 @@ async def get_user_ai_service( ) +async def get_user_ai_service_from_db(user_id: str, db: AsyncSession) -> AIService: + """ + 从数据库直接创建用户AI服务实例(用于后台任务,不依赖FastAPI的Depends) + """ + from app.models.mcp_plugin import MCPPlugin + + result = await db.execute( + select(Settings).where(Settings.user_id == user_id) + ) + settings = result.scalar_one_or_none() + + if not settings: + env_defaults = read_env_defaults() + settings = Settings(user_id=user_id, **env_defaults) + db.add(settings) + await db.commit() + await db.refresh(settings) + + mcp_result = await db.execute( + select(MCPPlugin).where(MCPPlugin.user_id == user_id) + ) + mcp_plugins = mcp_result.scalars().all() + enable_mcp = any(plugin.enabled for plugin in mcp_plugins) if mcp_plugins else False + + return create_user_ai_service_with_mcp( + api_provider=settings.api_provider, + api_key=settings.api_key, + api_base_url=settings.api_base_url or "", + model_name=settings.llm_model, + temperature=settings.temperature, + max_tokens=settings.max_tokens, + user_id=user_id, + db_session=db, + system_prompt=settings.system_prompt, + enable_mcp=enable_mcp, + ) + + @router.get("", response_model=SettingsResponse) async def get_settings( user: User = Depends(require_login), diff --git a/backend/app/api/tasks.py b/backend/app/api/tasks.py new file mode 100644 index 0000000..6c6f092 --- /dev/null +++ b/backend/app/api/tasks.py @@ -0,0 +1,122 @@ +"""后台任务API - 查询状态、取消任务""" +from fastapi import APIRouter, Depends, HTTPException, Request +from sqlalchemy.ext.asyncio import AsyncSession +from typing import Optional + +from app.database import get_db +from app.models.background_task import BackgroundTask +from app.services.background_task_service import background_task_service +from app.logger import get_logger + +router = APIRouter(prefix="/tasks", tags=["后台任务"]) +logger = get_logger(__name__) + + +@router.get("/{task_id}", summary="获取任务状态") +async def get_task_status( + task_id: str, + request: Request, + db: AsyncSession = Depends(get_db) +): + """获取后台任务的状态和进度""" + user_id = getattr(request.state, 'user_id', None) + if not user_id: + raise HTTPException(status_code=401, detail="未登录") + + task = await background_task_service.get_task(task_id, user_id, db) + if not task: + raise HTTPException(status_code=404, detail="任务不存在") + + return { + "id": task.id, + "task_type": task.task_type, + "project_id": task.project_id, + "status": task.status, + "progress": task.progress, + "status_message": task.status_message, + "progress_details": task.progress_details, + "error_message": task.error_message, + "task_result": task.task_result, + "retry_count": task.retry_count, + "cancel_requested": task.cancel_requested, + "created_at": task.created_at.isoformat() if task.created_at else None, + "started_at": task.started_at.isoformat() if task.started_at else None, + "completed_at": task.completed_at.isoformat() if task.completed_at else None, + "updated_at": task.updated_at.isoformat() if task.updated_at else None, + } + + +@router.get("", summary="获取任务列表") +async def get_tasks( + project_id: str, + request: Request, + task_type: Optional[str] = None, + limit: int = 20, + db: AsyncSession = Depends(get_db) +): + """获取项目的后台任务列表""" + user_id = getattr(request.state, 'user_id', None) + if not user_id: + raise HTTPException(status_code=401, detail="未登录") + + tasks = await background_task_service.get_project_tasks( + project_id, user_id, db, task_type=task_type, limit=limit + ) + + return { + "items": [ + { + "id": t.id, + "task_type": t.task_type, + "status": t.status, + "progress": t.progress, + "status_message": t.status_message, + "progress_details": t.progress_details, + "error_message": t.error_message, + "created_at": t.created_at.isoformat() if t.created_at else None, + "completed_at": t.completed_at.isoformat() if t.completed_at else None, + } + for t in tasks + ] + } + + +@router.post("/{task_id}/cancel", summary="取消任务") +async def cancel_task( + task_id: str, + request: Request, + db: AsyncSession = Depends(get_db) +): + """请求取消后台任务""" + user_id = getattr(request.state, 'user_id', None) + if not user_id: + raise HTTPException(status_code=401, detail="未登录") + + success = await background_task_service.cancel_task(task_id, user_id, db) + if not success: + raise HTTPException(status_code=400, detail="无法取消任务(不存在或已完成)") + + return {"message": "任务已取消", "task_id": task_id} + + +@router.delete("/{task_id}", summary="删除任务记录") +async def delete_task( + task_id: str, + request: Request, + db: AsyncSession = Depends(get_db) +): + """删除已完成/失败的任务记录""" + user_id = getattr(request.state, 'user_id', None) + if not user_id: + raise HTTPException(status_code=401, detail="未登录") + + task = await background_task_service.get_task(task_id, user_id, db) + if not task: + raise HTTPException(status_code=404, detail="任务不存在") + + if task.status in ("pending", "running"): + raise HTTPException(status_code=400, detail="无法删除进行中的任务,请先取消") + + await db.delete(task) + await db.commit() + return {"message": "任务记录已删除"} \ No newline at end of file diff --git a/backend/app/database.py b/backend/app/database.py index 31fc56e..66a32eb 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -21,7 +21,8 @@ from app.models import ( Settings, WritingStyle, ProjectDefaultStyle, RelationshipType, CharacterRelationship, Organization, OrganizationMember, StoryMemory, PlotAnalysis, AnalysisTask, BatchGenerationTask, - RegenerationTask, Career, CharacterCareer, User, MCPPlugin, PromptTemplate + RegenerationTask, Career, CharacterCareer, User, MCPPlugin, PromptTemplate, + BackgroundTask ) # 引擎缓存:每个用户一个引擎 diff --git a/backend/app/main.py b/backend/app/main.py index 7006b1e..5e63598 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -29,7 +29,21 @@ async def lifespan(app: FastAPI): """应用生命周期管理""" # 注册MCP状态同步服务 register_status_sync() - + + # 安全保障:确保后台任务表存在(兼容未执行Alembic迁移的旧部署) + try: + from app.database import get_engine + from app.models.background_task import BackgroundTask + _startup_engine = await get_engine("system") + async with _startup_engine.begin() as conn: + # 仅创建 background_tasks 表(如果不存在),不影响其他表 + await conn.run_sync( + lambda sync_conn: BackgroundTask.__table__.create(sync_conn, checkfirst=True) + ) + logger.info("后台任务表检查完成") + except Exception as e: + logger.warning(f"后台任务表检查失败(不影响启动): {e}") + logger.info("应用启动完成") yield @@ -133,7 +147,7 @@ from app.api import ( auth, users, settings, writing_styles, memories, mcp_plugins, admin, inspiration, prompt_templates, changelog, careers, foreshadows, prompt_workshop, book_import, - project_covers + project_covers, tasks ) app.include_router(auth.router, prefix="/api") @@ -159,6 +173,7 @@ app.include_router(prompt_templates.router, prefix="/api") # 提示词模板管 app.include_router(changelog.router, prefix="/api") # 更新日志API app.include_router(prompt_workshop.router, prefix="/api") # 提示词工坊API app.include_router(book_import.router, prefix="/api") # 拆书导入API +app.include_router(tasks.router, prefix="/api") # 后台任务API static_dir = Path(__file__).parent.parent / "static" generated_assets_root_dir = Path(__file__).parent.parent / "storage" diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 7e9a442..8f19b11 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -18,6 +18,7 @@ from app.models.career import Career, CharacterCareer from app.models.prompt_template import PromptTemplate from app.models.foreshadow import Foreshadow from app.models.prompt_workshop import PromptWorkshopItem, PromptSubmission, PromptWorkshopLike +from app.models.background_task import BackgroundTask __all__ = [ "Project", @@ -46,5 +47,6 @@ __all__ = [ "Foreshadow", "PromptWorkshopItem", "PromptSubmission", - "PromptWorkshopLike" -] \ No newline at end of file + "PromptWorkshopLike", + "BackgroundTask" +] diff --git a/backend/app/models/background_task.py b/backend/app/models/background_task.py new file mode 100644 index 0000000..0c9d353 --- /dev/null +++ b/backend/app/models/background_task.py @@ -0,0 +1,46 @@ +"""后台任务数据模型 - 用于长时间运行的AI生成任务""" +from sqlalchemy import Column, String, Integer, DateTime, Boolean, JSON, Text +from sqlalchemy.sql import func +from app.database import Base +import uuid + + +class BackgroundTask(Base): + """后台任务表 - 追踪所有长时间运行的生成任务""" + __tablename__ = "background_tasks" + + id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + user_id = Column(String(100), nullable=False, index=True, comment="用户ID") + project_id = Column(String(36), nullable=False, index=True, comment="项目ID") + + # 任务类型 + task_type = Column(String(50), nullable=False, comment="任务类型: outline_new/outline_continue/outline_expand/chapter_generate/chapter_batch/wizard") + + # 任务状态 + status = Column(String(20), default="pending", comment="任务状态: pending/running/completed/failed/cancelled") + progress = Column(Integer, default=0, comment="进度百分比(0-100)") + status_message = Column(String(500), comment="当前状态消息") + + # 任务输入/输出 + task_input = Column(JSON, comment="任务输入参数(JSON)") + task_result = Column(JSON, comment="任务结果(JSON)") + error_message = Column(Text, comment="错误信息") + + # 进度详情(用于前端展示实时进度) + progress_details = Column(JSON, comment="进度详情: {stage, message, word_count, etc.}") + + # 取消支持 + cancel_requested = Column(Boolean, default=False, comment="是否请求取消") + + # 重试信息 + retry_count = Column(Integer, default=0, comment="已重试次数") + max_retries = Column(Integer, default=3, comment="最大重试次数") + + # 时间记录 + created_at = Column(DateTime, server_default=func.now(), comment="创建时间") + started_at = Column(DateTime, comment="开始时间") + completed_at = Column(DateTime, comment="完成时间") + updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间") + + def __repr__(self): + return f"" \ No newline at end of file diff --git a/backend/app/schemas/chapter.py b/backend/app/schemas/chapter.py index 24b12f6..d13a2f7 100644 --- a/backend/app/schemas/chapter.py +++ b/backend/app/schemas/chapter.py @@ -138,6 +138,7 @@ class BatchGenerateRequest(BaseModel): enable_mcp: bool = Field(True, description="是否启用MCP工具增强(搜索参考资料)") max_retries: int = Field(3, description="每个章节的最大重试次数", ge=0, le=5) model: Optional[str] = Field(None, description="指定使用的AI模型,不提供则使用用户默认模型") + narrative_perspective: Optional[str] = Field(None, description="临时指定叙事人称,不提供则使用项目默认") class BatchGenerateResponse(BaseModel): diff --git a/backend/app/services/background_task_service.py b/backend/app/services/background_task_service.py new file mode 100644 index 0000000..18655dc --- /dev/null +++ b/backend/app/services/background_task_service.py @@ -0,0 +1,387 @@ +"""后台任务管理服务 - 管理长时间运行的AI生成任务""" +import asyncio +from datetime import datetime +from typing import Dict, Any, Optional, Callable, Awaitable +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker +from sqlalchemy import select, update +from app.database import get_engine +from app.models.background_task import BackgroundTask +from app.logger import get_logger + +logger = get_logger(__name__) + + +class TaskProgressTracker: + """后台任务进度追踪器(替代SSE的WizardProgressTracker)""" + + def __init__(self, task_id: str, user_id: str, task_name: str = "任务"): + self.task_id = task_id + self.user_id = user_id + self.task_name = task_name + self.current_progress = 0 + self._last_generating_progress = 20 + + async def _update_task(self, **kwargs): + """更新任务状态到数据库""" + try: + engine = await get_engine(self.user_id) + AsyncSessionLocal = async_sessionmaker( + engine, class_=AsyncSession, expire_on_commit=False + ) + async with AsyncSessionLocal() as session: + result = await session.execute( + select(BackgroundTask).where(BackgroundTask.id == self.task_id) + ) + task = result.scalar_one_or_none() + if task: + for key, value in kwargs.items(): + setattr(task, key, value) + task.updated_at = datetime.now() + await session.commit() + except Exception as e: + logger.error(f"❌ 更新任务进度失败: {e}") + + async def start(self, message: str = None): + self.current_progress = 0 + msg = message or f"开始生成{self.task_name}..." + await self._update_task( + status="running", progress=0, status_message=msg, + started_at=datetime.now(), + progress_details={"stage": "init", "message": msg} + ) + + async def loading(self, message: str = None, sub_progress: float = 0.5): + progress = 5 + int(10 * sub_progress) + self.current_progress = progress + msg = message or "加载数据中..." + await self._update_task( + progress=progress, status_message=msg, + progress_details={"stage": "loading", "message": msg} + ) + + async def preparing(self, message: str = None): + self.current_progress = 17 + msg = message or "准备AI提示词..." + await self._update_task( + progress=17, status_message=msg, + progress_details={"stage": "preparing", "message": msg} + ) + + async def generating(self, current_chars: int = 0, estimated_total: int = 5000, + message: str = None, retry_count: int = 0, max_retries: int = 3): + sub_progress = min(current_chars / max(estimated_total, 1), 1.0) + progress = 20 + int(65 * sub_progress) + if progress < self._last_generating_progress: + progress = self._last_generating_progress + else: + self._last_generating_progress = progress + self.current_progress = progress + + retry_suffix = f" (重试 {retry_count}/{max_retries})" if retry_count > 0 else "" + msg = message or f"生成{self.task_name}中... ({current_chars}字符){retry_suffix}" + await self._update_task( + progress=progress, status_message=msg, + progress_details={"stage": "generating", "message": msg, "current_chars": current_chars} + ) + + async def parsing(self, message: str = None): + self.current_progress = 88 + msg = message or f"解析{self.task_name}数据..." + await self._update_task( + progress=88, status_message=msg, + progress_details={"stage": "parsing", "message": msg} + ) + + async def saving(self, message: str = None, sub_progress: float = 0.5): + progress = 92 + int(6 * sub_progress) + self.current_progress = progress + msg = message or f"保存{self.task_name}到数据库..." + await self._update_task( + progress=progress, status_message=msg, + progress_details={"stage": "saving", "message": msg} + ) + + async def complete(self, message: str = None): + self.current_progress = 100 + msg = message or f"{self.task_name}生成完成!" + await self._update_task( + status="completed", progress=100, status_message=msg, + completed_at=datetime.now(), + progress_details={"stage": "complete", "message": msg} + ) + + async def error(self, error_message: str): + await self._update_task( + status="failed", error_message=error_message, + status_message=f"失败: {error_message}", + completed_at=datetime.now(), + progress_details={"stage": "error", "message": error_message} + ) + + async def warning(self, message: str): + await self._update_task( + status_message=f"⚠️ {message}", + progress_details={"stage": "warning", "message": message} + ) + + async def retry(self, retry_count: int, max_retries: int, reason: str = "准备重试"): + msg = f"⚠️ {reason}... ({retry_count}/{max_retries})" + await self._update_task( + status_message=msg, retry_count=retry_count, + progress_details={"stage": "retry", "message": msg, "retry_count": retry_count} + ) + + def reset_generating_progress(self): + self._last_generating_progress = 20 + + async def check_cancelled(self) -> bool: + """检查任务是否被取消""" + try: + engine = await get_engine(self.user_id) + AsyncSessionLocal = async_sessionmaker( + engine, class_=AsyncSession, expire_on_commit=False + ) + async with AsyncSessionLocal() as session: + result = await session.execute( + select(BackgroundTask.cancel_requested) + .where(BackgroundTask.id == self.task_id) + ) + cancelled = result.scalar_one_or_none() + return bool(cancelled) + except Exception: + return False + + +class BackgroundTaskService: + """后台任务管理服务(按用户排队:同用户任务逐个执行,不同用户可并发)""" + + def __init__(self): + self._user_queues: Dict[str, asyncio.Queue] = {} # user_id -> Queue + self._user_workers: Dict[str, bool] = {} # user_id -> worker是否运行中 + + def _ensure_user_queue(self, user_id: str) -> asyncio.Queue: + """确保指定用户的队列已初始化""" + if user_id not in self._user_queues: + self._user_queues[user_id] = asyncio.Queue() + return self._user_queues[user_id] + + async def _start_user_worker(self, user_id: str): + """启动指定用户的工作协程""" + if self._user_workers.get(user_id, False): + return + self._user_workers[user_id] = True + asyncio.create_task(self._user_worker_loop(user_id)) + logger.info(f"📋 用户 {user_id[:8]} 的任务队列工作协程已启动") + + async def _user_worker_loop(self, user_id: str): + """从指定用户的队列中逐个取出任务并执行""" + queue = self._user_queues[user_id] + try: + while True: + try: + task_item = await queue.get() + task_id = task_item["task_id"] + task_func = task_item["task_func"] + args = task_item["args"] + kwargs = task_item["kwargs"] + + logger.info(f"🔄 [用户{user_id[:8]}] 队列开始执行任务: {task_id[:8]} (队列剩余: {queue.qsize()})") + + try: + await task_func(task_id, args["user_id"], *args["extra_args"], **kwargs) + except Exception as e: + logger.error(f"❌ 后台任务 {task_id[:8]} 异常: {e}", exc_info=True) + # 确保任务状态更新为失败 + try: + engine = await get_engine(user_id) + AsyncSessionLocal = async_sessionmaker( + engine, class_=AsyncSession, expire_on_commit=False + ) + async with AsyncSessionLocal() as session: + result = await session.execute( + select(BackgroundTask).where(BackgroundTask.id == task_id) + ) + task = result.scalar_one_or_none() + if task and task.status == "running": + task.status = "failed" + task.error_message = str(e) + task.status_message = f"任务失败: {str(e)}" + task.completed_at = datetime.now() + await session.commit() + except Exception as update_err: + logger.error(f"❌ 更新失败任务状态失败: {update_err}") + finally: + queue.task_done() + logger.info(f"✅ [用户{user_id[:8]}] 队列任务完成: {task_id[:8]} (队列剩余: {queue.qsize()})") + + except Exception as e: + logger.error(f"❌ [用户{user_id[:8]}] 队列工作循环异常: {e}", exc_info=True) + finally: + # 工作协程退出时清理标记 + self._user_workers.pop(user_id, None) + logger.info(f"📋 用户 {user_id[:8]} 的工作协程已退出") + + @staticmethod + async def create_task( + user_id: str, + project_id: str, + task_type: str, + task_input: Dict[str, Any] = None, + db: AsyncSession = None + ) -> BackgroundTask: + """创建后台任务记录""" + task = BackgroundTask( + user_id=user_id, + project_id=project_id, + task_type=task_type, + task_input=task_input or {}, + status="pending", + progress=0, + status_message="任务已创建,等待执行..." + ) + db.add(task) + await db.commit() + await db.refresh(task) + logger.info(f"📋 创建后台任务: {task.id[:8]} type={task_type} project={project_id[:8]}") + return task + + @staticmethod + async def get_task(task_id: str, user_id: str, db: AsyncSession) -> Optional[BackgroundTask]: + """获取任务详情""" + result = await db.execute( + select(BackgroundTask).where( + BackgroundTask.id == task_id, + BackgroundTask.user_id == user_id + ) + ) + return result.scalar_one_or_none() + + @staticmethod + async def get_project_tasks( + project_id: str, user_id: str, db: AsyncSession, + task_type: str = None, limit: int = 20 + ) -> list: + """获取项目的任务列表""" + query = ( + select(BackgroundTask) + .where( + BackgroundTask.project_id == project_id, + BackgroundTask.user_id == user_id + ) + .order_by(BackgroundTask.created_at.desc()) + ) + if task_type: + query = query.where(BackgroundTask.task_type == task_type) + query = query.limit(limit) + result = await db.execute(query) + return result.scalars().all() + + @staticmethod + async def cancel_task(task_id: str, user_id: str, db: AsyncSession) -> bool: + """请求取消任务""" + result = await db.execute( + select(BackgroundTask).where( + BackgroundTask.id == task_id, + BackgroundTask.user_id == user_id + ) + ) + task = result.scalar_one_or_none() + if not task: + return False + if task.status not in ("pending", "running"): + return False + task.cancel_requested = True + task.status = "cancelled" + task.status_message = "任务已取消" + task.completed_at = datetime.now() + await db.commit() + logger.info(f"🚫 取消任务: {task_id[:8]}") + return True + + @staticmethod + async def cleanup_old_tasks(user_id: str, db: AsyncSession, days: int = 7): + """清理旧任务记录""" + from sqlalchemy import delete as sql_delete + from datetime import timedelta + cutoff = datetime.now() - timedelta(days=days) + result = await db.execute( + sql_delete(BackgroundTask).where( + BackgroundTask.user_id == user_id, + BackgroundTask.status.in_(["completed", "failed", "cancelled"]), + BackgroundTask.completed_at < cutoff + ) + ) + if result.rowcount > 0: + await db.commit() + logger.info(f"🧹 清理用户 {user_id[:8]} 的 {result.rowcount} 条旧任务记录") + + async def spawn_background_task( + self, + task_id: str, + user_id: str, + task_func: Callable[..., Awaitable], + *args, + **kwargs + ): + """ + 将任务加入该用户的队列排队执行(同一用户FIFO,不同用户可并发) + + Args: + task_id: 任务ID + user_id: 用户ID + task_func: 异步任务函数 + *args, **kwargs: 传递给task_func的参数 + """ + # 确保该用户的队列和工作协程已启动 + queue = self._ensure_user_queue(user_id) + await self._start_user_worker(user_id) + + # 将任务放入该用户的队列 + await queue.put({ + "task_id": task_id, + "task_func": task_func, + "args": {"user_id": user_id, "extra_args": args}, + "kwargs": kwargs, + }) + queue_size = queue.qsize() + logger.info(f"📥 任务已加入用户 {user_id[:8]} 的队列: {task_id[:8]} (当前队列长度: {queue_size})") + + # 更新任务状态,显示排队位置 + try: + engine = await get_engine(user_id) + AsyncSessionLocal = async_sessionmaker( + engine, class_=AsyncSession, expire_on_commit=False + ) + async with AsyncSessionLocal() as session: + result = await session.execute( + select(BackgroundTask).where(BackgroundTask.id == task_id) + ) + task = result.scalar_one_or_none() + if task and task.status == "pending": + if queue_size > 0: + task.status_message = f"排队中,前方还有 {queue_size} 个任务等待..." + else: + task.status_message = "即将开始执行..." + task.progress_details = {"stage": "queued", "queue_size": queue_size} + task.updated_at = datetime.now() + await session.commit() + except Exception as e: + logger.error(f"更新队列位置信息失败: {e}") + + def get_queue_size(self, user_id: str = None) -> int: + """获取队列中等待的任务数量""" + if user_id: + queue = self._user_queues.get(user_id) + return queue.qsize() if queue else 0 + # 所有用户队列总数 + return sum(q.qsize() for q in self._user_queues.values()) + + def get_all_queue_info(self) -> Dict[str, int]: + """获取所有用户的队列信息""" + return { + uid: q.qsize() for uid, q in self._user_queues.items() if q.qsize() > 0 + } + + +# 全局单例 +background_task_service = BackgroundTaskService() \ No newline at end of file diff --git a/backend/app/services/foreshadow_service.py b/backend/app/services/foreshadow_service.py index 555e47c..50129a2 100644 --- a/backend/app/services/foreshadow_service.py +++ b/backend/app/services/foreshadow_service.py @@ -1284,7 +1284,7 @@ class ForeshadowService: planted_foreshadows = await self.get_planted_foreshadows_for_analysis(db, project_id) # 每章最多创建的新伏笔数量 - MAX_NEW_FORESHADOWS_PER_CHAPTER = 2 + MAX_NEW_FORESHADOWS_PER_CHAPTER = 5 new_foreshadow_count = 0 for fs_data in analysis_foreshadows: diff --git a/backend/app/services/json_helper.py b/backend/app/services/json_helper.py index 318a03e..44e33db 100644 --- a/backend/app/services/json_helper.py +++ b/backend/app/services/json_helper.py @@ -26,15 +26,101 @@ _QUOTE_MAP = { } +def _is_content_quote(text: str, pos: int) -> bool: + """ + 判断字符串值内的 '"' 是否为内容引号(需转义)而非 JSON 结束引号。 + + 合法 JSON 中,字符串结束引号之后的非空白字符必须是: + ',' (值分隔) / '}' (关闭对象) / ']' (关闭数组) + + 如果 '"' 后面不符合这些模式,则是 AI 写入的内容引号,需要转义。 + """ + j = pos + 1 + + # 跳过空格和制表符 + while j < len(text) and text[j] in ' \t': + j += 1 + + if j >= len(text): + return False # 文本末尾,视为结束引号 + + ch = text[j] + + # } 或 ] → 结束引号 + if ch in ('}', ']'): + return False + + # 换行 → 检查下一行开头判断 + if ch == '\n' or ch == '\r': + k = j + (2 if (ch == '\r' and j + 1 < len(text) and text[j + 1] == '\n') else 1) + while k < len(text) and text[k] in ' \t': + k += 1 + if k >= len(text): + return False + # 下一行以 " (JSON key) 或 } 或 ] 开头 → 结束引号 + if text[k] == '"' or text[k] in ('}', ']'): + return False + return True + + # , → 需要检查逗号后面是什么 + if ch == ',': + k = j + 1 + while k < len(text) and text[k] in ' \t': + k += 1 + + if k >= len(text): + return False + + # 逗号后跟换行 → 检查下一行 + if text[k] in ('\n', '\r'): + k2 = k + (2 if (text[k] == '\r' and k + 1 < len(text) and text[k + 1] == '\n') else 1) + while k2 < len(text) and text[k2] in ' \t\n\r': + k2 += 1 + if k2 >= len(text): + return False + if text[k2] == '"' or text[k2] in ('}', ']'): + return False + return True + + after_comma = text[k] + + # 结构性逗号后应为 JSON 值的开头 + if after_comma == '"': + return False # 字符串值或 key + if after_comma.isdigit() or after_comma == '-': + return False # 数字 + if after_comma in ('{', '['): + return False # 对象/数组 + if text[k:k+4] in ('true', 'null'): + return False + if text[k:k+5] == 'false': + return False + + # 逗号后不是 JSON 值开头 → 内容逗号,引号是内容引号 + return True + + # : → 通常在字符串结束后不可能出现,保守处理为结束引号 + if ch == ':': + return False + + # 其他字符(中文、字母等)→ 内容引号 + return True + + def _fix_json_string_values(text: str) -> str: """ - 修复JSON字符串值中的常见问题: - 1. 裸换行符/制表符 → 转义 - 2. 字符串值内的中文引号 → 转义为ASCII引号(避免破坏JSON结构) - 3. 结构位置的中文引号 → 直接替换为ASCII引号 + 上下文感知的 JSON 修复,区分字符串内外分别处理。 - AI生成的JSON常在字符串值中插入未转义的换行符和中文引号。 - 此函数遍历文本,区分字符串内外,分别处理。 + 字符串值内: + 1. 裸换行符/制表符 → 转义 + 2. 中文引号(""等) → 转义为 \\" + 3. 未转义的 ASCII 双引号 → 智能检测:内容引号转义,结束引号保留 + 4. 中文逗号/冒号 → 保留原样(是内容字符) + + 结构位置(字符串外): + 1. 中文引号 → ASCII 引号 + 2. 中文逗号 → ASCII 逗号 + 3. 中文冒号 → ASCII 冒号 """ if not text or '"' not in text: return text @@ -47,111 +133,234 @@ def _fix_json_string_values(text: str) -> str: while i < len(text): c = text[i] - if c == '"' and not in_string: - # 进入字符串 - in_string = True + # === 非字符串内(结构位置)=== + if not in_string: + # 结构位置的中文标点 → ASCII + if c == '\uff0c': # ,→ , + result.append(',') + fixed_count += 1 + i += 1 + continue + if c == '\uff1a': # :→ : + result.append(':') + fixed_count += 1 + i += 1 + continue + if c in _QUOTE_MAP: + result.append(_QUOTE_MAP[c]) + fixed_count += 1 + i += 1 + continue + + # ASCII 双引号 → 进入字符串 + if c == '"': + in_string = True + result.append(c) + i += 1 + continue + result.append(c) i += 1 continue - if in_string: - if c == '\\': - # 转义字符,检查下一个字符是否合法 - if i + 1 < len(text): - next_c = text[i + 1] - # JSON 合法转义:\" \\ \/ \b \f \n \r \t \uXXXX - if next_c in ('"', '\\', '/', 'b', 'f', 'n', 'r', 't'): - # 合法转义,直接保留 - result.append(c) - result.append(next_c) - i += 2 + # === 字符串值内 === + + # 转义字符处理 + if c == '\\': + if i + 1 < len(text): + next_c = text[i + 1] + if next_c in ('"', '\\', '/', 'b', 'f', 'n', 'r', 't'): + result.append(c) + result.append(next_c) + i += 2 + continue + elif next_c == 'u': + if i + 5 < len(text) and all(text[i+2+k] in '0123456789abcdefABCDEF' for k in range(4)): + result.append(text[i:i+6]) + i += 6 continue - elif next_c == 'u': - # Unicode 转义 \uXXXX,检查是否有4个十六进制字符 - if i + 5 < len(text) and all(text[i+2+k] in '0123456789abcdefABCDEF' for k in range(4)): - result.append(text[i:i+6]) - i += 6 - continue - else: - # 不完整的unicode转义,去掉反斜杠 - result.append(next_c) - fixed_count += 1 - i += 2 - continue else: - # 非法转义字符(如 \c \p \d 等),去掉反斜杠只保留字符 result.append(next_c) fixed_count += 1 i += 2 continue else: - # 末尾孤立的反斜杠,去掉 + result.append(next_c) fixed_count += 1 - i += 1 + i += 2 continue - - if c == '"': - # 字符串结束 + else: + fixed_count += 1 + i += 1 + continue + + # ASCII 双引号 → 智能判断是结束引号还是内容引号 + if c == '"': + if _is_content_quote(text, i): + # 内容引号,需要转义 + result.append('\\') + result.append('"') + fixed_count += 1 + i += 1 + continue + else: + # 结束引号 in_string = False result.append(c) i += 1 continue - - if c == '\n': - # 裸换行符 → 替换为转义换行 - result.append('\\') - result.append('n') - fixed_count += 1 - i += 1 - continue - - if c == '\r': - # 裸回车符 → 忽略或替换 - if i + 1 < len(text) and text[i + 1] == '\n': - result.append('\\') - result.append('n') - fixed_count += 1 - i += 2 - else: - result.append('\\') - result.append('n') - fixed_count += 1 - i += 1 - continue - - if c == '\t': - # 裸制表符 → 替换为转义制表符 - result.append('\\') - result.append('t') - fixed_count += 1 - i += 1 - continue - - # 字符串值内的中文引号 → 转义为 \"(避免破坏JSON结构) - if c in _QUOTE_MAP: - result.append('\\') - result.append(_QUOTE_MAP[c]) - fixed_count += 1 - i += 1 - continue - # 非字符串内的字符 - # 结构位置的中文引号 → 直接替换 - if not in_string and c in _QUOTE_MAP: - result.append(_QUOTE_MAP[c]) + # 裸换行符 → 转义 + if c == '\n': + result.append('\\') + result.append('n') fixed_count += 1 i += 1 continue + if c == '\r': + if i + 1 < len(text) and text[i + 1] == '\n': + result.append('\\') + result.append('n') + fixed_count += 1 + i += 2 + else: + result.append('\\') + result.append('n') + fixed_count += 1 + i += 1 + continue + + if c == '\t': + result.append('\\') + result.append('t') + fixed_count += 1 + i += 1 + continue + + # 中文引号处理 + if c in _QUOTE_MAP: + mapped = _QUOTE_MAP[c] + if mapped == '"': + # 中文双引号在字符串内需要转义 + result.append('\\') + result.append('"') + else: + # 中文单引号在双引号字符串内不需要转义,直接替换 + result.append(mapped) + fixed_count += 1 + i += 1 + continue + + # 其他字符(包括中文逗号、中文冒号)→ 保留原样 result.append(c) i += 1 if fixed_count > 0: - logger.debug(f"✅ 修复了{fixed_count}个JSON问题(裸控制字符/中文引号)") + logger.debug(f"✅ 修复了{fixed_count}个JSON问题(引号/控制字符/中文标点)") return ''.join(result) +def _fix_all_invalid_escapes(text: str) -> str: + """ + 兜底修复:扫描整个文本中的无效JSON转义序列。 + + 当 _fix_json_string_values 因字符串边界追踪错误而遗漏某些无效转义时, + 此函数作为兜底,不依赖字符串状态追踪,扫描整个文本修复所有无效转义。 + + 有效JSON转义:\\" \\\\ \\/ \\b \\f \\n \\r \\t \\uXXXX + 其他 \\X 均为无效转义,修复方式为去掉反斜杠只保留字符。 + """ + if '\\' not in text: + return text + + result = [] + i = 0 + fixed = 0 + + while i < len(text): + if text[i] == '\\' and i + 1 < len(text): + next_c = text[i + 1] + if next_c in ('"', '\\', '/', 'b', 'f', 'n', 'r', 't'): + # 有效转义,保留 + result.append(text[i]) + result.append(next_c) + i += 2 + continue + elif next_c == 'u': + # Unicode 转义,检查是否有4个十六进制字符 + if i + 5 < len(text) and all( + text[i + 2 + k] in '0123456789abcdefABCDEF' + for k in range(4) + ): + result.append(text[i:i + 6]) + i += 6 + continue + else: + # 不完整的unicode转义,去掉反斜杠 + result.append(next_c) + fixed += 1 + i += 2 + continue + else: + # 无效转义(如 \引 \影 \某种 等),去掉反斜杠只保留字符 + result.append(next_c) + fixed += 1 + i += 2 + continue + else: + result.append(text[i]) + i += 1 + + if fixed > 0: + logger.info(f"✅ 兜底修复了{fixed}个无效JSON转义序列") + + return ''.join(result) + + +def _fix_multiple_objects_as_value(text: str) -> str: + """ + 修复AI生成的JSON中,多个对象作为属性值但未合并的问题。 + + 示例: + "key": {"a": "1"}, {"b": "2"} → "key": {"a": "1", "b": "2"} + + AI有时在输出对象类型的属性值时,输出了多个独立的对象而不是合并为一个。 + 例如 relationship_changes 字段输出多个角色关系变化时可能出现此问题。 + 此函数检测并合并这些对象。 + """ + if '{' not in text or '}' not in text: + return text + + # 匹配嵌套层级不超过2的对象: { ... } 其中 ... 不含 { 或仅含单层嵌套 + nested_obj = r'\{(?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*\}' + + # 模式:属性冒号后跟一个对象,然后逗号和另一个对象(没有属性名) + # 即 "key": {obj1}, {obj2} → "key": {obj1, obj2} + pattern = r'(":)\s*(' + nested_obj + r')\s*,\s*(' + nested_obj + r')' + + def merge_objects(match): + colon = match.group(1) + obj1_content = match.group(2)[1:-1] # 去掉外层的 { } + obj2_content = match.group(3)[1:-1] # 去掉外层的 { } + # 合并为一个对象 + return f'{colon} {{{obj1_content}, {obj2_content}}}' + + prev = None + count = 0 + max_iterations = 10 + while prev != text and count < max_iterations: + prev = text + text = re.sub(pattern, merge_objects, text) + count += 1 + + if count > 1: + logger.info(f"✅ 修复了{count - 1}处多对象属性值合并") + + return text + + def clean_json_response(text: str) -> str: """清洗 AI 返回的 JSON(改进版 - 流式安全)""" try: @@ -162,11 +371,8 @@ def clean_json_response(text: str) -> str: original_length = len(text) logger.debug(f"🔍 开始清洗JSON,原始长度: {original_length}") - # 替换中文逗号/冒号(AI可能在JSON结构位置使用,全局替换是安全的) - text = text.replace('\uff0c', ',') # ,→ , - text = text.replace('\uff1a', ':') # :→ : - - # 修复JSON中的中文引号和裸控制字符(上下文感知,区分字符串内外) + # 上下文感知修复:中文引号/逗号/冒号、裸控制字符、未转义的内容引号 + # (区分字符串内外:结构位置替换为ASCII,字符串内保留或转义) text = _fix_json_string_values(text) # 去除 markdown 代码块 @@ -286,9 +492,35 @@ def clean_json_response(text: str) -> str: json.loads(result) logger.debug(f"✅ 清洗后JSON验证成功") except json.JSONDecodeError as e: - logger.error(f"❌ 清洗后JSON仍然无效: {e}") - logger.debug(f" 结果预览: {result[:500]}") - logger.debug(f" 结果结尾: ...{result[-200:]}") + logger.warning(f"⚠️ 清洗后JSON仍然无效: {e},尝试修复结构性问题...") + + # 修复1:合并多对象属性值(AI可能输出 "key": {a:1}, {b:2} ) + result = _fix_multiple_objects_as_value(result) + + try: + json.loads(result) + logger.info(f"✅ 修复多对象属性值后JSON验证成功") + except json.JSONDecodeError: + pass # 继续尝试其他修复 + else: + return result + + # 修复2:兜底修复无效转义序列(不依赖字符串边界追踪) + logger.warning(f"⚠️ 继续尝试兜底修复无效转义...") + result = _fix_all_invalid_escapes(result) + try: + json.loads(result) + logger.info(f"✅ 兜底修复后JSON验证成功") + except json.JSONDecodeError as e2: + # 修复3:再次尝试合并多对象属性值(转义修复后可能产生新的合并机会) + result = _fix_multiple_objects_as_value(result) + try: + json.loads(result) + logger.info(f"✅ 二次修复后JSON验证成功") + except json.JSONDecodeError as e3: + logger.error(f"❌ 所有修复后JSON仍然无效: {e3}") + logger.debug(f" 结果预览: {result[:500]}") + logger.debug(f" 结果结尾: ...{result[-200:]}") return result @@ -339,6 +571,16 @@ def loads_json(text: str) -> Any: except (json.JSONDecodeError, Exception): pass + # 兜底修复无效转义序列后重试 + fixed_text = _fix_all_invalid_escapes(text) + if fixed_text != text: + try: + result = json.loads(fixed_text) + logger.info("✅ 兜底修复无效转义后json.loads成功") + return result + except (json.JSONDecodeError, Exception): + pass + # json5 容错解析 if HAS_JSON5: try: @@ -347,6 +589,14 @@ def loads_json(text: str) -> Any: logger.info("✅ json5容错解析成功") return result except Exception as e5: + # json5也失败,尝试对修复后的文本使用json5 + if fixed_text != text: + try: + result = json5.loads(fixed_text) + logger.info("✅ 兜底修复无效转义后json5容错解析成功") + return result + except Exception: + pass logger.error(f"❌ json5容错解析也失败: {e5}") # 最终失败,抛出标准异常 diff --git a/docker-compose.yml b/docker-compose.yml index 8f871e2..d5e9962 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -53,7 +53,7 @@ services: build: context: . dockerfile: Dockerfile - image: mumujie/mumuainovel:latest + image: mumujie/mumuainovel:dev container_name: mumuainovel depends_on: postgres: diff --git a/frontend/pnpm-lock.yaml b/frontend/pnpm-lock.yaml index 7173b4a..ba8b9bc 100644 --- a/frontend/pnpm-lock.yaml +++ b/frontend/pnpm-lock.yaml @@ -23,6 +23,9 @@ importers: '@types/canvas-confetti': specifier: ^1.9.0 version: 1.9.0 + '@types/dagre': + specifier: ^0.7.54 + version: 0.7.54 '@xyflow/react': specifier: ^12.10.1 version: 12.10.1(@types/react@18.3.28)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) @@ -35,6 +38,9 @@ importers: canvas-confetti: specifier: ^1.9.4 version: 1.9.4 + dagre: + specifier: ^0.8.5 + version: 0.8.5 dayjs: specifier: ^1.11.13 version: 1.11.19 @@ -734,6 +740,9 @@ packages: '@types/d3-zoom@3.0.8': resolution: {integrity: sha512-iqMC4/YlFCSlO8+2Ii1GGGliCAY4XdeG748w5vQUbevlbDu0zSjH/+jojorQVBK/se0j6DUFNPBGSqD3YWYnDw==} + '@types/dagre@0.7.54': + resolution: {integrity: sha512-QjcRY+adGbYvBFS7cwv5txhVIwX1XXIUswWl+kSQTbI6NjgZydrZkEKX/etzVd7i+bCsCb40Z/xlBY5eoFuvWQ==} + '@types/estree@1.0.8': resolution: {integrity: sha512-dWHzHa2WqEXI/O1E9OjrocMTKJl2mSrEolh1Iomrv6U+JuNwaHXsXx9bLu5gG7BUWFIN0skIQJQ/L1rIex4X6w==} @@ -996,6 +1005,9 @@ packages: resolution: {integrity: sha512-b8AmV3kfQaqWAuacbPuNbL6vahnOJflOhexLzMMNLga62+/nh0JzvJ0aO/5a5MVgUFGS7Hu1P9P03o3fJkDCyw==} engines: {node: '>=12'} + dagre@0.8.5: + resolution: {integrity: sha512-/aTqmnRta7x7MCCpExk7HQL2O4owCT2h8NT//9I1OQ9vt29Pa0BzSAkR5lwFUcQ7491yVi/3CXU9jQ5o0Mn2Sw==} + dayjs@1.11.19: resolution: {integrity: sha512-t5EcLVS6QPBNqM2z8fakk/NKel+Xzshgt8FFKAn+qwlD1pzZWxh0nVCrvFK7ZDb6XucZeF9z8C7CBWTRIVApAw==} @@ -1200,6 +1212,9 @@ packages: resolution: {integrity: sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg==} engines: {node: '>= 0.4'} + graphlib@2.1.8: + resolution: {integrity: sha512-jcLLfkpoVGmH7/InMC/1hIvOPSUh38oJtGhvrOFGzioE1DZ+0YW16RgmOJhHiuWTvGiJQ9Z1Ik43JvkRPRvE+A==} + has-flag@4.0.0: resolution: {integrity: sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==} engines: {node: '>=8'} @@ -1299,6 +1314,9 @@ packages: lodash.merge@4.6.2: resolution: {integrity: sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==} + lodash@4.18.1: + resolution: {integrity: sha512-dMInicTPVE8d1e5otfwmmjlxkZoUpiVLwyeTdUsi/Caj/gfzzblBcCE5sRHV/AsjuCmxWrte2TNGSYuCeCq+0Q==} + loose-envify@1.4.0: resolution: {integrity: sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q==} hasBin: true @@ -2498,6 +2516,8 @@ snapshots: '@types/d3-interpolate': 3.0.4 '@types/d3-selection': 3.0.11 + '@types/dagre@0.7.54': {} + '@types/estree@1.0.8': {} '@types/json-schema@7.0.15': {} @@ -2861,6 +2881,11 @@ snapshots: d3-selection: 3.0.0 d3-transition: 3.0.1(d3-selection@3.0.0) + dagre@0.8.5: + dependencies: + graphlib: 2.1.8 + lodash: 4.18.1 + dayjs@1.11.19: {} debug@4.4.3: @@ -3082,6 +3107,10 @@ snapshots: gopd@1.2.0: {} + graphlib@2.1.8: + dependencies: + lodash: 4.18.1 + has-flag@4.0.0: {} has-symbols@1.1.0: {} @@ -3158,6 +3187,8 @@ snapshots: lodash.merge@4.6.2: {} + lodash@4.18.1: {} + loose-envify@1.4.0: dependencies: js-tokens: 4.0.0 diff --git a/frontend/src/pages/Chapters.tsx b/frontend/src/pages/Chapters.tsx index a4c5b22..f9ef218 100644 --- a/frontend/src/pages/Chapters.tsx +++ b/frontend/src/pages/Chapters.tsx @@ -1,15 +1,15 @@ import { useState, useEffect, useRef, useMemo, useCallback } from 'react'; -import { List, Button, Modal, Form, Input, Select, message, Empty, Space, Badge, Tag, Card, InputNumber, Alert, Radio, Descriptions, Collapse, Popconfirm, Pagination, theme } from 'antd'; -import { EditOutlined, FileTextOutlined, ThunderboltOutlined, LockOutlined, DownloadOutlined, SettingOutlined, FundOutlined, SyncOutlined, CheckCircleOutlined, CloseCircleOutlined, RocketOutlined, StopOutlined, InfoCircleOutlined, CaretRightOutlined, DeleteOutlined, BookOutlined, FormOutlined, PlusOutlined, ReadOutlined } from '@ant-design/icons'; +import { List, Button, Modal, Form, Input, Select, message, Empty, Space, Badge, Tag, Progress, Card, InputNumber, Alert, Radio, Descriptions, Collapse, Popconfirm, Pagination, theme } from 'antd'; +import { EditOutlined, FileTextOutlined, ThunderboltOutlined, LockOutlined, DownloadOutlined, SettingOutlined, FundOutlined, SyncOutlined, CheckCircleOutlined, CloseCircleOutlined, RocketOutlined, StopOutlined, InfoCircleOutlined, CaretRightOutlined, DeleteOutlined, BookOutlined, FormOutlined, PlusOutlined, ReadOutlined, ClockCircleOutlined, LoadingOutlined } from '@ant-design/icons'; import { useStore } from '../store'; import { useChapterSync } from '../store/hooks'; +import { generateChapterBackground, getProjectTasks, cancelTask, deleteTask, type TaskStatus as BgTaskStatus } from '../services/backgroundTaskService'; import { projectApi, writingStyleApi, chapterApi } from '../services/api'; import type { Chapter, ChapterUpdate, ApiError, WritingStyle, AnalysisTask, ExpansionPlanData } from '../types'; import type { TextAreaRef } from 'antd/es/input/TextArea'; import ChapterAnalysis from '../components/ChapterAnalysis'; import ExpansionPlanEditor from '../components/ExpansionPlanEditor'; import { SSELoadingOverlay } from '../components/SSELoadingOverlay'; -import { SSEProgressModal } from '../components/SSEProgressModal'; import ChapterReader from '../components/ChapterReader'; import PartialRegenerateToolbar from '../components/PartialRegenerateToolbar'; import PartialRegenerateModal from '../components/PartialRegenerateModal'; @@ -97,6 +97,112 @@ export default function Chapters() { const [singleChapterProgress, setSingleChapterProgress] = useState(0); const [singleChapterProgressMessage, setSingleChapterProgressMessage] = useState(''); + // 后台生成任务状态 + const [bgTaskVisible, setBgTaskVisible] = useState(false); + const [bgTaskProgress, setBgTaskProgress] = useState(0); + const [bgTaskMessage, setBgTaskMessage] = useState(''); + const [bgTaskRunning, setBgTaskRunning] = useState(false); + const bgTaskCancelRef = useRef<(() => void) | null>(null); + const [projectBgTasks, setProjectBgTasks] = useState([]); + const bgPollTimerRef = useRef | null>(null); + // 后台任务列表 Modal 状态 + const [taskListVisible, setTaskListVisible] = useState(false); + const [taskList, setTaskList] = useState([]); + const [taskListLoading, setTaskListLoading] = useState(false); + + // 轮询项目后台任务 + useEffect(() => { + if (!currentProject) return; + const pollBgTasks = async () => { + try { + const resp = await getProjectTasks(currentProject.id, 'chapter_generate', 10); + const active = resp.items.filter(t => t.status === 'pending' || t.status === 'running'); + setProjectBgTasks(active); + // 如果有活跃任务,继续轮询 + if (active.length > 0) { + bgPollTimerRef.current = setTimeout(pollBgTasks, 3000); + } + } catch {} + }; + pollBgTasks(); + return () => { if (bgPollTimerRef.current) clearTimeout(bgPollTimerRef.current); }; + }, [currentProject]); + + // 加载并显示后台任务列表 + const showTaskListModal = async () => { + if (!currentProject?.id) return; + setTaskListVisible(true); + setTaskListLoading(true); + try { + const result = await getProjectTasks(currentProject.id); + setTaskList(result.items || []); + } catch (error) { + message.error('加载任务列表失败'); + } finally { + setTaskListLoading(false); + } + }; + + // 刷新任务列表 + const refreshTaskList = async () => { + if (!currentProject?.id) return; + setTaskListLoading(true); + try { + const result = await getProjectTasks(currentProject.id); + setTaskList(result.items || []); + const active = (result.items || []).filter(t => t.status === 'pending' || t.status === 'running'); + setProjectBgTasks(active); + } catch (error) { + console.error('刷新任务列表失败:', error); + } finally { + setTaskListLoading(false); + } + }; + + // 获取任务状态标签 + const getTaskStatusTag = (status: BgTaskStatus['status']) => { + switch (status) { + case 'pending': return } color="default">等待中; + case 'running': return } color="processing">运行中; + case 'completed': return } color="success">已完成; + case 'failed': return } color="error">失败; + case 'cancelled': return } color="default">已取消; + default: return {status}; + } + }; + + // 获取任务类型标签 + const getTaskTypeLabel = (taskType: string) => { + switch (taskType) { + case 'chapter_generate': return '章节生成'; + case 'outline_new': return '大纲生成'; + case 'outline_continue': return '大纲续写'; + default: return taskType; + } + }; + + // 处理取消后台任务 + const handleCancelBgTask = async (taskId: string) => { + try { + await cancelTask(taskId); + message.success('任务已取消'); + refreshTaskList(); + } catch (error) { + message.error('取消任务失败'); + } + }; + + // 处理删除任务记录 + const handleDeleteBgTask = async (taskId: string) => { + try { + await deleteTask(taskId); + message.success('任务记录已删除'); + refreshTaskList(); + } catch (error) { + message.error('删除任务记录失败'); + } + }; + // 批量生成相关状态 const [batchGenerateVisible, setBatchGenerateVisible] = useState(false); const [batchGenerating, setBatchGenerating] = useState(false); @@ -523,7 +629,7 @@ export default function Chapters() { if (data.has_active_task && data.task) { const task = data.task; - // 恢复任务状态 + // 恢复任务状态(只在顶部进度条显示,不弹出Modal) setBatchTaskId(task.batch_id); setBatchProgress({ status: task.status, @@ -532,12 +638,12 @@ export default function Chapters() { current_chapter_number: task.current_chapter_number, }); setBatchGenerating(true); - setBatchGenerateVisible(true); + // 不设置 setBatchGenerateVisible(true),避免弹出Modal遮挡页面 // 启动轮询 startBatchPolling(task.batch_id); - message.info('检测到未完成的批量生成任务,已自动恢复'); + message.info('检测到未完成的批量生成任务,已在顶部显示进度'); } } catch (error) { console.error('检查批量生成任务失败:', error); @@ -971,9 +1077,62 @@ export default function Chapters() { }); }; + + // 后台生成章节(关闭浏览器也不影响) + const handleBackgroundGenerate = async () => { + if (!editingId) return; + if (!selectedStyleId) { + message.error("请先选择写作风格"); + return; + } + + try { + setBgTaskVisible(true); + setBgTaskRunning(true); + setBgTaskProgress(0); + setBgTaskMessage("正在创建后台任务..."); + + const cancelFn = await generateChapterBackground( + editingId, + { + style_id: selectedStyleId, + target_word_count: targetWordCount, + model: selectedModel, + narrative_perspective: temporaryNarrativePerspective, + }, + (status) => { + setBgTaskProgress(status.progress || 0); + setBgTaskMessage(status.status_message || "处理中..."); + }, + (_) => { + setBgTaskProgress(100); + setBgTaskMessage("生成完成!"); + setBgTaskRunning(false); + message.success("后台章节生成完成!"); + refreshChapters(); + if (currentProject) { + projectApi.getProject(currentProject.id).then(setCurrentProject).catch(console.error); + } + loadAnalysisTasks(); + }, + (error) => { + setBgTaskRunning(false); + setBgTaskMessage("失败: " + error); + message.error("后台生成失败: " + error); + } + ); + + bgTaskCancelRef.current = cancelFn; + message.info("已提交后台生成任务,可以关闭此页面"); + } catch (error) { + message.error("创建后台任务失败"); + setBgTaskRunning(false); + } + }; const getStatusColor = (status: string) => { const colors: Record = { 'draft': 'default', + 'pending': 'warning', 'writing': 'processing', 'completed': 'success', }; @@ -983,6 +1142,7 @@ export default function Chapters() { const getStatusText = (status: string) => { const texts: Record = { 'draft': '草稿', + 'pending': '待处理', 'writing': '创作中', 'completed': '已完成', }; @@ -1387,6 +1547,7 @@ export default function Chapters() { > @@ -1931,6 +2092,13 @@ export default function Chapters() { > 一键分析{batchAnalyzableChapterCount > 0 ? ` (${batchAnalyzableChapterCount})` : ''} + + + )} + {/* 单章节后台生成进度 */} + {projectBgTasks.map(task => ( +
+ + {task.status === 'running' ? '生成中' : '排队中'} + +
+
+
+
+
+ + {task.progress || 0}% + + + {task.status_message || ''} + +
+ ))} +
+ )} +
{chapters.length === 0 ? ( @@ -2435,6 +2700,7 @@ export default function Chapters() { @@ -2497,23 +2763,56 @@ export default function Chapters() { const disabledReason = currentChapter ? getGenerateDisabledReason(currentChapter) : ''; return ( + <> + + ); })()} + {/* 后台生成进度 */} + {bgTaskVisible && ( + +
{bgTaskMessage}
+
+
+
+
{bgTaskProgress}%
+
+ } + type={bgTaskRunning ? 'info' : (bgTaskProgress >= 100 ? 'success' : 'error')} + showIcon + style={{ marginBottom: 12 }} + closable={!bgTaskRunning} + onClose={() => setBgTaskVisible(false)} + /> + )} + {/* 第一行:写作风格 + 叙事角度 */}
- {/* 批量生成进度显示 - 使用统一的进度组件 */} - + + 后台任务 + {taskList.filter(t => t.status === 'running' || t.status === 'pending').length > 0 && ( + t.status === 'running' || t.status === 'pending').length} /> + )} + } - title="批量生成章节" - onCancel={() => { - modal.confirm({ - title: '确认取消', - content: '确定要取消批量生成吗?已生成的章节将保留。', - okText: '确定取消', - cancelText: '继续生成', - okButtonProps: { danger: true }, - centered: true, - onOk: handleCancelBatchGenerate, - }); - }} - cancelButtonText="取消任务" - /> + open={taskListVisible} + onCancel={() => setTaskListVisible(false)} + width={isMobile ? '95%' : 700} + centered + footer={ + + + + + } + > + {taskListLoading && taskList.length === 0 ? ( +
+ +
加载中...
+
+ ) : taskList.length === 0 ? ( + + ) : ( + ( + handleCancelBgTask(task.id)}>取消] + : [] + ), + ...(task.status === 'completed' || task.status === 'failed' || task.status === 'cancelled' + ? [] + : [] + ), + ].filter(Boolean)} + > + + {getTaskStatusTag(task.status)} + {getTaskTypeLabel(task.task_type)} + {task.status === 'running' || task.status === 'pending' ? ( + + ) : null} + + } + description={ +
+
+ {task.status_message || '无状态信息'} +
+
+ 创建: {task.created_at ? new Date(task.created_at).toLocaleString() : '-'} + {task.completed_at && ' | 完成: ' + new Date(task.completed_at).toLocaleString()} +
+ {task.error_message && ( +
+ {'❌ ' + task.error_message} +
+ )} + {task.task_result && task.status === 'completed' && ( +
+ {'✅ ' + ((task.task_result as Record).message as string || '任务完成')} +
+ )} +
+ } + /> +
+ )} + /> + )} + + {/* 章节阅读器 */} {readingChapter && ( diff --git a/frontend/src/pages/Outline.tsx b/frontend/src/pages/Outline.tsx index fb81b61..c529197 100644 --- a/frontend/src/pages/Outline.tsx +++ b/frontend/src/pages/Outline.tsx @@ -1,10 +1,11 @@ -import { useState, useEffect, useMemo } from 'react'; -import { Button, List, Modal, Form, Input, message, Empty, Space, Popconfirm, Card, Select, Radio, Tag, InputNumber, Tabs, Pagination, theme } from 'antd'; -import { EditOutlined, DeleteOutlined, ThunderboltOutlined, BranchesOutlined, AppstoreAddOutlined, CheckCircleOutlined, ExclamationCircleOutlined, PlusOutlined, FileTextOutlined } from '@ant-design/icons'; +import { useState, useEffect, useMemo, useRef } from 'react'; +import { Button, List, Modal, Form, Input, message, Empty, Space, Popconfirm, Card, Select, Radio, Tag, InputNumber, Tabs, Pagination, theme, Progress, Badge, Tooltip } from 'antd'; +import { EditOutlined, DeleteOutlined, ThunderboltOutlined, BranchesOutlined, AppstoreAddOutlined, CheckCircleOutlined, ExclamationCircleOutlined, PlusOutlined, FileTextOutlined, ClockCircleOutlined, ReloadOutlined, CloseCircleOutlined, LoadingOutlined } from '@ant-design/icons'; import { useStore } from '../store'; import { useOutlineSync } from '../store/hooks'; import { SSEPostClient } from '../utils/sseClient'; import { SSEProgressModal } from '../components/SSEProgressModal'; +import { generateOutlineBackground, getProjectTasks, cancelTask, deleteTask, type TaskStatus } from '../services/backgroundTaskService'; import { outlineApi, chapterApi, projectApi, characterApi } from '../services/api'; import type { OutlineExpansionResponse, BatchOutlineExpansionResponse, ChapterPlanItem, ApiError, Character } from '../types'; @@ -154,6 +155,14 @@ export default function Outline() { const [sseMessage, setSSEMessage] = useState(''); const [sseModalVisible, setSSEModalVisible] = useState(false); + // 后台任务取消函数引用 + const cancelGenerateRef = useRef<(() => void) | null>(null); + + // 后台任务列表状态 + const [taskListVisible, setTaskListVisible] = useState(false); + const [taskList, setTaskList] = useState([]); + const [taskListLoading, setTaskListLoading] = useState(false); + useEffect(() => { const handleResize = () => { setIsMobile(window.innerWidth <= 768); @@ -573,33 +582,31 @@ export default function Outline() { console.log('6. 最终请求数据:', JSON.stringify(requestData, null, 2)); console.log('========================='); - // 使用SSE客户端 - const apiUrl = `/api/outlines/generate-stream`; - const client = new SSEPostClient(apiUrl, requestData, { - onProgress: (msg: string, progress: number) => { - setSSEMessage(msg); - setSSEProgress(progress); + // 使用后台任务生成(不怕断连,关闭浏览器也继续运行) + setSSEMessage('正在创建后台任务...'); + + const cancelFn = await generateOutlineBackground( + requestData, + (status) => { + setSSEProgress(status.progress); + setSSEMessage(status.status_message || '处理中...'); }, - onResult: (data: unknown) => { - console.log('生成完成,结果:', data); + (result) => { + message.success(result.task_result?.message as string || '大纲生成完成!'); + setSSEModalVisible(false); + setIsGenerating(false); + cancelGenerateRef.current = null; + refreshOutlines(); }, - onError: (error: string) => { - // 现在只处理真正的错误 + (error) => { message.error(`生成失败: ${error}`); setSSEModalVisible(false); setIsGenerating(false); - }, - onComplete: () => { - message.success('大纲生成完成!'); - setSSEModalVisible(false); - setIsGenerating(false); - // 刷新大纲列表 - refreshOutlines(); + cancelGenerateRef.current = null; } - }); + ); - // 开始连接 - client.connect(); + cancelGenerateRef.current = cancelFn; } catch (error) { console.error('AI生成失败:', error); @@ -1895,8 +1902,168 @@ export default function Outline() { }; + // 加载并显示后台任务列表 + const showTaskListModal = async () => { + if (!currentProject?.id) return; + setTaskListVisible(true); + setTaskListLoading(true); + try { + const result = await getProjectTasks(currentProject.id); + setTaskList(result.items || []); + } catch (error) { + message.error('加载任务列表失败'); + } finally { + setTaskListLoading(false); + } + }; + + // 刷新任务列表 + const refreshTaskList = async () => { + if (!currentProject?.id) return; + setTaskListLoading(true); + try { + const result = await getProjectTasks(currentProject.id); + setTaskList(result.items || []); + } catch (error) { + console.error('刷新任务列表失败:', error); + } finally { + setTaskListLoading(false); + } + }; + + // 获取任务状态标签 + const getTaskStatusTag = (status: TaskStatus['status']) => { + switch (status) { + case 'pending': return } color="default">等待中; + case 'running': return } color="processing">运行中; + case 'completed': return } color="success">已完成; + case 'failed': return } color="error">失败; + case 'cancelled': return } color="default">已取消; + default: return {status}; + } + }; + + // 获取任务类型标签 + const getTaskTypeLabel = (taskType: string) => { + switch (taskType) { + case 'outline_new': return '大纲生成'; + case 'outline_continue': return '大纲续写'; + default: return taskType; + } + }; + + // 处理取消后台任务 + const handleCancelTask = async (taskId: string) => { + try { + await cancelTask(taskId); + message.success('任务已取消'); + refreshTaskList(); + } catch (error) { + message.error('取消任务失败'); + } + }; + + // 处理删除任务记录 + const handleDeleteTask = async (taskId: string) => { + try { + await deleteTask(taskId); + message.success('任务记录已删除'); + refreshTaskList(); + } catch (error) { + message.error('删除任务记录失败'); + } + }; + return ( <> + {/* 后台任务列表 Modal */} + + + 后台任务 + {taskList.filter(t => t.status === 'running' || t.status === 'pending').length > 0 && ( + t.status === 'running' || t.status === 'pending').length} /> + )} + + } + open={taskListVisible} + onCancel={() => setTaskListVisible(false)} + width={isMobile ? '95%' : 700} + centered + footer={ + + + + + } + > + {taskListLoading && taskList.length === 0 ? ( +
+ +
加载中...
+
+ ) : taskList.length === 0 ? ( + + ) : ( + ( + handleCancelTask(task.id)}>取消] + : [] + ), + ...(task.status === 'completed' || task.status === 'failed' || task.status === 'cancelled' + ? [] + : [] + ), + ].filter(Boolean)} + > + + {getTaskStatusTag(task.status)} + {getTaskTypeLabel(task.task_type)} + {task.status === 'running' || task.status === 'pending' ? ( + + ) : null} + + } + description={ +
+
+ {task.status_message || '无状态信息'} +
+
+ 创建: {task.created_at ? new Date(task.created_at).toLocaleString() : '-'} + {task.completed_at && ` | 完成: ${new Date(task.completed_at).toLocaleString()}`} +
+ {task.error_message && ( +
+ ❌ {task.error_message} +
+ )} + {task.task_result && task.status === 'completed' && ( +
+ ✅ {(task.task_result as Record).message as string || '任务完成'} +
+ )} +
+ } + /> +
+ )} + /> + )} +
+ {/* 批量展开预览 Modal */} { + if (cancelGenerateRef.current) { + cancelGenerateRef.current(); + cancelGenerateRef.current = null; + } + setSSEModalVisible(false); + setIsGenerating(false); + message.info('已取消生成任务'); + }} />
@@ -1977,6 +2153,15 @@ export default function Outline() { > {isMobile ? 'AI生成/续写' : 'AI生成/续写大纲'} + + + {outlines.length > 0 && currentProject?.outline_mode === 'one-to-many' && (