2bd8b61e91
新功能: - 大纲/章节生成改为服务端后台任务,支持断线续传 - 后台任务队列排队执行,按用户排队(同用户串行不同用户并发) - 章节管理页面添加后台任务列表弹窗和进度面板 - 章节状态添加 pending(待处理)选项 - 集成json5容错解析器 + 上下文感知JSON修复 - SSE流式生成添加心跳保活,防止连接超时 - SSEPostClient添加credentials:include修复network error - 每章最大伏笔数从2调整为5 - 添加大纲读取伏笔的功能 Bug修复: - 修复AI生成JSON中未转义引号/中文标点/多对象属性值未合并 - 修复JSON非法转义字符清洗和中文引号处理 - 修复MCP插件TimeoutError/连接失败上下文清理 - MCP插件后台注册添加重试机制 - 续写模式添加缺失的mcp_references参数 - 修复Alembic迁移链分叉 - 使用torch CPU版本加速Docker构建
387 lines
16 KiB
Python
387 lines
16 KiB
Python
"""后台任务管理服务 - 管理长时间运行的AI生成任务"""
|
||
import asyncio
|
||
from datetime import datetime
|
||
from typing import Dict, Any, Optional, Callable, Awaitable
|
||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||
from sqlalchemy import select, update
|
||
from app.database import get_engine
|
||
from app.models.background_task import BackgroundTask
|
||
from app.logger import get_logger
|
||
|
||
# Module-level logger used by both TaskProgressTracker and BackgroundTaskService below.
logger = get_logger(__name__)
|
||
|
||
|
||
class TaskProgressTracker:
    """Persist the progress of one background task (replaces the SSE WizardProgressTracker).

    Every stage method writes a progress value, a status message and a
    ``progress_details`` dict onto the task's ``BackgroundTask`` row. Stages map
    to fixed slices of the 0-100 progress bar: start 0, loading 5-15,
    preparing 17, generating 20-85 (monotonic until reset), parsing 88,
    saving 92-98, complete 100.

    Once the row has status ``"cancelled"``, further updates are ignored so a
    generator that is still running cannot overwrite the cancelled state.
    """

    def __init__(self, task_id: str, user_id: str, task_name: str = "任务"):
        self.task_id = task_id
        self.user_id = user_id
        # Human-readable name interpolated into the default status messages.
        self.task_name = task_name
        self.current_progress = 0
        # High-water mark of the generating stage: progress reported during
        # generation never moves backwards until reset_generating_progress().
        self._last_generating_progress = 20

    @staticmethod
    def _clamp_unit(value: float) -> float:
        """Clamp ``value`` into the closed interval [0.0, 1.0]."""
        return max(0.0, min(float(value), 1.0))

    async def _session_factory(self):
        """Return an ``async_sessionmaker`` bound to this user's database engine."""
        engine = await get_engine(self.user_id)
        return async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    async def _update_task(self, **kwargs) -> None:
        """Write ``kwargs`` as column values onto this task's row (best-effort).

        Also bumps ``updated_at``. Missing rows and rows already marked
        ``"cancelled"`` are left untouched. Errors are logged, never raised, so
        progress reporting cannot break the task itself.
        """
        try:
            session_factory = await self._session_factory()
            async with session_factory() as session:
                result = await session.execute(
                    select(BackgroundTask).where(BackgroundTask.id == self.task_id)
                )
                task = result.scalar_one_or_none()
                if task is None:
                    return
                if task.status == "cancelled":
                    # The cancel request already wrote the final state; do not
                    # flip it back to running/completed/failed.
                    return
                for key, value in kwargs.items():
                    setattr(task, key, value)
                task.updated_at = datetime.now()
                await session.commit()
        except Exception as e:
            logger.error(f"❌ 更新任务进度失败: {e}")

    async def start(self, message: Optional[str] = None):
        """Mark the task as running at 0% and record ``started_at``."""
        self.current_progress = 0
        msg = message or f"开始生成{self.task_name}..."
        await self._update_task(
            status="running", progress=0, status_message=msg,
            started_at=datetime.now(),
            progress_details={"stage": "init", "message": msg},
        )

    async def loading(self, message: Optional[str] = None, sub_progress: float = 0.5):
        """Report the data-loading stage (5-15%); ``sub_progress`` is clamped to [0, 1]."""
        progress = 5 + int(10 * self._clamp_unit(sub_progress))
        self.current_progress = progress
        msg = message or "加载数据中..."
        await self._update_task(
            progress=progress, status_message=msg,
            progress_details={"stage": "loading", "message": msg},
        )

    async def preparing(self, message: Optional[str] = None):
        """Report the prompt-preparation stage (17%)."""
        self.current_progress = 17
        msg = message or "准备AI提示词..."
        await self._update_task(
            progress=17, status_message=msg,
            progress_details={"stage": "preparing", "message": msg},
        )

    async def generating(self, current_chars: int = 0, estimated_total: int = 5000,
                         message: Optional[str] = None, retry_count: int = 0,
                         max_retries: int = 3):
        """Report generation progress (20-85%), proportional to characters produced.

        Progress never decreases relative to earlier ``generating`` calls until
        ``reset_generating_progress`` is called. ``estimated_total`` values
        below 1 are treated as 1 to avoid division by zero.
        """
        ratio = min(current_chars / max(estimated_total, 1), 1.0)
        progress = max(20 + int(65 * ratio), self._last_generating_progress)
        self._last_generating_progress = progress
        self.current_progress = progress

        retry_suffix = f" (重试 {retry_count}/{max_retries})" if retry_count > 0 else ""
        msg = message or f"生成{self.task_name}中... ({current_chars}字符){retry_suffix}"
        await self._update_task(
            progress=progress, status_message=msg,
            progress_details={"stage": "generating", "message": msg, "current_chars": current_chars},
        )

    async def parsing(self, message: Optional[str] = None):
        """Report the result-parsing stage (88%)."""
        self.current_progress = 88
        msg = message or f"解析{self.task_name}数据..."
        await self._update_task(
            progress=88, status_message=msg,
            progress_details={"stage": "parsing", "message": msg},
        )

    async def saving(self, message: Optional[str] = None, sub_progress: float = 0.5):
        """Report the saving stage (92-98%); ``sub_progress`` is clamped to [0, 1]."""
        progress = 92 + int(6 * self._clamp_unit(sub_progress))
        self.current_progress = progress
        msg = message or f"保存{self.task_name}到数据库..."
        await self._update_task(
            progress=progress, status_message=msg,
            progress_details={"stage": "saving", "message": msg},
        )

    async def complete(self, message: Optional[str] = None):
        """Mark the task completed at 100% and record ``completed_at``."""
        self.current_progress = 100
        msg = message or f"{self.task_name}生成完成!"
        await self._update_task(
            status="completed", progress=100, status_message=msg,
            completed_at=datetime.now(),
            progress_details={"stage": "complete", "message": msg},
        )

    async def error(self, error_message: str):
        """Mark the task failed with ``error_message``; the progress value is left unchanged."""
        await self._update_task(
            status="failed", error_message=error_message,
            status_message=f"失败: {error_message}",
            completed_at=datetime.now(),
            progress_details={"stage": "error", "message": error_message},
        )

    async def warning(self, message: str):
        """Show a warning in the status message without changing status or progress."""
        await self._update_task(
            status_message=f"⚠️ {message}",
            progress_details={"stage": "warning", "message": message},
        )

    async def retry(self, retry_count: int, max_retries: int, reason: str = "准备重试"):
        """Record a retry attempt in ``retry_count`` and surface it in the status message."""
        msg = f"⚠️ {reason}... ({retry_count}/{max_retries})"
        await self._update_task(
            status_message=msg, retry_count=retry_count,
            progress_details={"stage": "retry", "message": msg, "retry_count": retry_count},
        )

    def reset_generating_progress(self):
        """Reset the generating-stage high-water mark so a new attempt can start again at 20%."""
        self._last_generating_progress = 20

    async def check_cancelled(self) -> bool:
        """Return True when the task row's ``cancel_requested`` flag is set.

        Best-effort: a missing row or any database error yields False (the
        error is logged) so a transient failure does not abort the task.
        """
        try:
            session_factory = await self._session_factory()
            async with session_factory() as session:
                result = await session.execute(
                    select(BackgroundTask.cancel_requested)
                    .where(BackgroundTask.id == self.task_id)
                )
                return bool(result.scalar_one_or_none())
        except Exception as e:
            logger.warning(f"⚠️ 检查任务取消状态失败: {e}")
            return False
|
||
|
||
|
||
class BackgroundTaskService:
    """Background task service with one FIFO queue per user.

    Tasks submitted for the same user are executed one after another by a
    single worker coroutine; each user gets their own worker, so different
    users' tasks can run concurrently. Queues are in-memory ``asyncio.Queue``
    objects owned by this instance.
    """

    def __init__(self):
        self._user_queues: Dict[str, asyncio.Queue] = {}  # user_id -> Queue
        self._user_workers: Dict[str, bool] = {}  # user_id -> whether a worker is running
        # user_id -> worker asyncio.Task. The event loop keeps only weak
        # references to tasks, so we hold a strong one to keep a worker from
        # being garbage-collected mid-execution.
        self._worker_tasks: Dict[str, asyncio.Task] = {}
        # user_id -> id of the item that user's worker has dequeued and is handling.
        self._running_tasks: Dict[str, str] = {}

    @staticmethod
    async def _session_factory(user_id: str):
        """Return an ``async_sessionmaker`` bound to ``user_id``'s database engine."""
        engine = await get_engine(user_id)
        return async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    def _ensure_user_queue(self, user_id: str) -> asyncio.Queue:
        """Return the user's queue, creating it on first use."""
        if user_id not in self._user_queues:
            self._user_queues[user_id] = asyncio.Queue()
        return self._user_queues[user_id]

    async def _start_user_worker(self, user_id: str):
        """Start the user's worker coroutine unless one is already running."""
        if self._user_workers.get(user_id, False):
            return
        self._user_workers[user_id] = True
        self._worker_tasks[user_id] = asyncio.create_task(self._user_worker_loop(user_id))
        logger.info(f"📋 用户 {user_id[:8]} 的任务队列工作协程已启动")

    async def _user_worker_loop(self, user_id: str):
        """Take tasks from the user's queue one at a time and run them.

        A failure while handling one item is logged and the loop moves on to
        the next item; only a BaseException (e.g. cancellation of the worker)
        ends the loop, after which the worker bookkeeping is cleared.
        """
        queue = self._user_queues[user_id]
        try:
            while True:
                task_item = await queue.get()
                # Mark the worker busy right away so spawn_background_task
                # counts this item as "ahead" of newly queued tasks.
                self._running_tasks[user_id] = task_item.get("task_id")
                try:
                    await self._run_queued_task(user_id, task_item, queue)
                except Exception as e:
                    logger.error(f"❌ [用户{user_id[:8]}] 队列工作循环异常: {e}", exc_info=True)
                finally:
                    self._running_tasks.pop(user_id, None)
                    queue.task_done()
        finally:
            # Clear the worker markers so the next spawn starts a fresh worker.
            self._user_workers.pop(user_id, None)
            self._worker_tasks.pop(user_id, None)
            self._running_tasks.pop(user_id, None)
            logger.info(f"📋 用户 {user_id[:8]} 的工作协程已退出")

    async def _run_queued_task(self, user_id: str, task_item: Dict[str, Any], queue: asyncio.Queue):
        """Run one dequeued item, skipping it if it was cancelled while waiting."""
        task_id = task_item["task_id"]
        task_func = task_item["task_func"]
        args = task_item["args"]
        kwargs = task_item["kwargs"]

        # cancel_task() only updates the DB row; the item is still in the
        # queue, so skip it here instead of starting cancelled work.
        if await self._is_task_cancelled(user_id, task_id):
            logger.info(f"🚫 [用户{user_id[:8]}] 跳过已取消的任务: {task_id[:8]}")
            return

        logger.info(f"🔄 [用户{user_id[:8]}] 队列开始执行任务: {task_id[:8]} (队列剩余: {queue.qsize()})")
        try:
            await task_func(task_id, args["user_id"], *args["extra_args"], **kwargs)
        except Exception as e:
            logger.error(f"❌ 后台任务 {task_id[:8]} 异常: {e}", exc_info=True)
            await self._mark_task_failed(user_id, task_id, e)
        finally:
            logger.info(f"✅ [用户{user_id[:8]}] 队列任务完成: {task_id[:8]} (队列剩余: {queue.qsize()})")

    async def _is_task_cancelled(self, user_id: str, task_id: str) -> bool:
        """Return True if the task row is cancelled; False if missing or on DB error."""
        try:
            session_factory = await self._session_factory(user_id)
            async with session_factory() as session:
                result = await session.execute(
                    select(BackgroundTask.status, BackgroundTask.cancel_requested)
                    .where(BackgroundTask.id == task_id)
                )
                row = result.one_or_none()
        except Exception as e:
            logger.warning(f"⚠️ 检查任务 {task_id[:8]} 取消状态失败: {e}")
            return False
        if row is None:
            return False
        status, cancel_requested = row
        return bool(cancel_requested) or status == "cancelled"

    async def _mark_task_failed(self, user_id: str, task_id: str, error: Exception):
        """Best-effort: mark a still-active task row as failed with ``error``."""
        try:
            session_factory = await self._session_factory(user_id)
            async with session_factory() as session:
                result = await session.execute(
                    select(BackgroundTask).where(BackgroundTask.id == task_id)
                )
                task = result.scalar_one_or_none()
                # "pending" too: task_func may raise before its tracker marks
                # the row as running, which would otherwise leave it pending forever.
                if task and task.status in ("pending", "running"):
                    now = datetime.now()
                    task.status = "failed"
                    task.error_message = str(error)
                    task.status_message = f"任务失败: {str(error)}"
                    task.completed_at = now
                    task.updated_at = now
                    await session.commit()
        except Exception as update_err:
            logger.error(f"❌ 更新失败任务状态失败: {update_err}")

    async def _publish_queue_position(self, user_id: str, task_id: str,
                                      tasks_ahead: int, queue_size: int):
        """Best-effort: write the queue position onto a still-pending task row."""
        try:
            session_factory = await self._session_factory(user_id)
            async with session_factory() as session:
                result = await session.execute(
                    select(BackgroundTask).where(BackgroundTask.id == task_id)
                )
                task = result.scalar_one_or_none()
                if task and task.status == "pending":
                    if tasks_ahead > 0:
                        task.status_message = f"排队中,前方还有 {tasks_ahead} 个任务等待..."
                    else:
                        task.status_message = "即将开始执行..."
                    task.progress_details = {
                        "stage": "queued",
                        "queue_size": queue_size,
                        "tasks_ahead": tasks_ahead,
                    }
                    task.updated_at = datetime.now()
                    await session.commit()
        except Exception as e:
            logger.error(f"更新队列位置信息失败: {e}")

    @staticmethod
    async def create_task(
        user_id: str,
        project_id: str,
        task_type: str,
        task_input: Dict[str, Any] = None,
        db: AsyncSession = None
    ) -> BackgroundTask:
        """Insert a new ``pending`` task row and return it (refreshed after commit).

        Raises:
            ValueError: if ``db`` is not provided.
        """
        if db is None:
            raise ValueError("create_task requires an AsyncSession (db)")
        task = BackgroundTask(
            user_id=user_id,
            project_id=project_id,
            task_type=task_type,
            task_input=task_input or {},
            status="pending",
            progress=0,
            status_message="任务已创建,等待执行..."
        )
        db.add(task)
        await db.commit()
        await db.refresh(task)
        logger.info(f"📋 创建后台任务: {task.id[:8]} type={task_type} project={project_id[:8]}")
        return task

    @staticmethod
    async def get_task(task_id: str, user_id: str, db: AsyncSession) -> Optional[BackgroundTask]:
        """Return the task with ``task_id`` owned by ``user_id``, or None."""
        result = await db.execute(
            select(BackgroundTask).where(
                BackgroundTask.id == task_id,
                BackgroundTask.user_id == user_id
            )
        )
        return result.scalar_one_or_none()

    @staticmethod
    async def get_project_tasks(
        project_id: str, user_id: str, db: AsyncSession,
        task_type: str = None, limit: int = 20
    ) -> list:
        """Return up to ``limit`` of the user's tasks for a project, newest first.

        When ``task_type`` is given, only tasks of that type are returned.
        """
        query = (
            select(BackgroundTask)
            .where(
                BackgroundTask.project_id == project_id,
                BackgroundTask.user_id == user_id
            )
            .order_by(BackgroundTask.created_at.desc())
        )
        if task_type:
            query = query.where(BackgroundTask.task_type == task_type)
        query = query.limit(limit)
        result = await db.execute(query)
        return result.scalars().all()

    @staticmethod
    async def cancel_task(task_id: str, user_id: str, db: AsyncSession) -> bool:
        """Mark a pending or running task as cancelled.

        Only the database row is updated. A still-queued item is skipped by the
        worker when dequeued; a running task_func is expected to poll its
        cancellation flag. Returns False if the task does not exist, is not
        owned by ``user_id``, or has already finished.
        """
        result = await db.execute(
            select(BackgroundTask).where(
                BackgroundTask.id == task_id,
                BackgroundTask.user_id == user_id
            )
        )
        task = result.scalar_one_or_none()
        if not task:
            return False
        if task.status not in ("pending", "running"):
            return False
        task.cancel_requested = True
        task.status = "cancelled"
        task.status_message = "任务已取消"
        task.completed_at = datetime.now()
        await db.commit()
        logger.info(f"🚫 取消任务: {task_id[:8]}")
        return True

    @staticmethod
    async def cleanup_old_tasks(user_id: str, db: AsyncSession, days: int = 7):
        """Delete the user's finished tasks whose ``completed_at`` is older than ``days`` days."""
        from sqlalchemy import delete as sql_delete
        from datetime import timedelta

        cutoff = datetime.now() - timedelta(days=days)
        result = await db.execute(
            sql_delete(BackgroundTask).where(
                BackgroundTask.user_id == user_id,
                BackgroundTask.status.in_(["completed", "failed", "cancelled"]),
                BackgroundTask.completed_at < cutoff
            )
        )
        if result.rowcount > 0:
            await db.commit()
            logger.info(f"🧹 清理用户 {user_id[:8]} 的 {result.rowcount} 条旧任务记录")

    async def spawn_background_task(
        self,
        task_id: str,
        user_id: str,
        task_func: Callable[..., Awaitable],
        *args,
        **kwargs
    ):
        """Queue a task on the user's FIFO queue (per-user serial, cross-user concurrent).

        The worker eventually calls ``task_func(task_id, user_id, *args, **kwargs)``.

        Args:
            task_id: ID of the task row created by ``create_task``.
            user_id: Owner of the task; selects the queue and the database engine.
            task_func: Async callable that performs the work.
            *args, **kwargs: Extra arguments forwarded to ``task_func``.
        """
        queue = self._ensure_user_queue(user_id)

        # Work that must finish before this task starts: everything already
        # waiting in the queue, plus the item the worker is currently handling.
        tasks_ahead = queue.qsize() + (1 if user_id in self._running_tasks else 0)
        # Publish the position *before* enqueueing, so the worker cannot pick the
        # task up and mark it running while the "queued" message is being written.
        await self._publish_queue_position(user_id, task_id, tasks_ahead, queue.qsize() + 1)

        await self._start_user_worker(user_id)
        await queue.put({
            "task_id": task_id,
            "task_func": task_func,
            "args": {"user_id": user_id, "extra_args": args},
            "kwargs": kwargs,
        })
        logger.info(f"📥 任务已加入用户 {user_id[:8]} 的队列: {task_id[:8]} (当前队列长度: {queue.qsize()})")

    def get_queue_size(self, user_id: str = None) -> int:
        """Return the number of waiting (not yet dequeued) tasks.

        For one user when ``user_id`` is given, otherwise summed over all users.
        """
        if user_id:
            queue = self._user_queues.get(user_id)
            return queue.qsize() if queue else 0
        return sum(q.qsize() for q in self._user_queues.values())

    def get_all_queue_info(self) -> Dict[str, int]:
        """Return ``{user_id: waiting_count}`` for users with a non-empty queue."""
        return {
            uid: q.qsize() for uid, q in self._user_queues.items() if q.qsize() > 0
        }
|
||
|
||
|
||
# Global singleton: every importer of this module shares this one service
# instance and therefore the same per-user queues and workers.
background_task_service = BackgroundTaskService()