feat: 后台任务系统 + JSON容错解析 + SSE心跳保活 + 多项Bug修复
新功能: - 大纲/章节生成改为服务端后台任务,支持断线续传 - 后台任务队列排队执行,按用户排队(同用户串行不同用户并发) - 章节管理页面添加后台任务列表弹窗和进度面板 - 章节状态添加 pending(待处理)选项 - 集成json5容错解析器 + 上下文感知JSON修复 - SSE流式生成添加心跳保活,防止连接超时 - SSEPostClient添加credentials:include修复network error - 每章最大伏笔数从2调整为5 - 添加大纲读取伏笔的功能 Bug修复: - 修复AI生成JSON中未转义引号/中文标点/多对象属性值未合并 - 修复JSON非法转义字符清洗和中文引号处理 - 修复MCP插件TimeoutError/连接失败上下文清理 - MCP插件后台注册添加重试机制 - 续写模式添加缺失的mcp_references参数 - 修复Alembic迁移链分叉 - 使用torch CPU版本加速Docker构建
This commit is contained in:
@@ -0,0 +1,387 @@
|
||||
"""后台任务管理服务 - 管理长时间运行的AI生成任务"""
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional, Callable, Awaitable
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from sqlalchemy import select, update
|
||||
from app.database import get_engine
|
||||
from app.models.background_task import BackgroundTask
|
||||
from app.logger import get_logger
|
||||
|
||||
# Module-level logger, obtained from the project's get_logger helper.
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class TaskProgressTracker:
    """Progress tracker for a background task (replaces the SSE WizardProgressTracker).

    Each stage method records the new progress on the instance and writes it,
    together with a status message and a ``progress_details`` dict, onto the
    ``BackgroundTask`` row whose id is ``task_id``.

    Progress values written by the stage methods: ``start`` 0, ``loading`` 5-15,
    ``preparing`` 17, ``generating`` 20-85, ``parsing`` 88, ``saving`` 92-98,
    ``complete`` 100.
    """

    def __init__(self, task_id: str, user_id: str, task_name: str = "任务"):
        """Remember which task row to update and the name used in messages."""
        self.task_id = task_id
        self.user_id = user_id
        self.task_name = task_name
        self.current_progress = 0
        # Highest value reported by generating(); keeps that stage monotonic.
        self._last_generating_progress = 20

    @staticmethod
    def _clamp_fraction(value: float) -> float:
        """Clamp a sub-progress fraction into [0.0, 1.0]."""
        return max(0.0, min(float(value), 1.0))

    async def _session_factory(self):
        """Build a session factory bound to the engine of this tracker's user."""
        engine = await get_engine(self.user_id)
        return async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    async def _update_task(self, **kwargs):
        """Write the given column values onto the task row.

        Best-effort: any failure is logged and swallowed so that a progress
        update can never abort the task that reports it. A missing row is
        ignored.
        """
        try:
            session_factory = await self._session_factory()
            async with session_factory() as session:
                result = await session.execute(
                    select(BackgroundTask).where(BackgroundTask.id == self.task_id)
                )
                task = result.scalar_one_or_none()
                if task:
                    for key, value in kwargs.items():
                        setattr(task, key, value)
                    task.updated_at = datetime.now()
                    await session.commit()
        except Exception as e:
            logger.error(f"❌ 更新任务进度失败: {e}")

    async def start(self, message: Optional[str] = None):
        """Mark the task as running at 0% and stamp ``started_at``."""
        self.current_progress = 0
        msg = message or f"开始生成{self.task_name}..."
        await self._update_task(
            status="running", progress=0, status_message=msg,
            started_at=datetime.now(),
            progress_details={"stage": "init", "message": msg}
        )

    async def loading(self, message: Optional[str] = None, sub_progress: float = 0.5):
        """Report the loading stage (5-15%).

        ``sub_progress`` is clamped to [0, 1] so the value stays inside the
        stage's band, matching what ``generating`` already does.
        """
        progress = 5 + int(10 * self._clamp_fraction(sub_progress))
        self.current_progress = progress
        msg = message or "加载数据中..."
        await self._update_task(
            progress=progress, status_message=msg,
            progress_details={"stage": "loading", "message": msg}
        )

    async def preparing(self, message: Optional[str] = None):
        """Report the prompt-preparation stage (fixed at 17%)."""
        self.current_progress = 17
        msg = message or "准备AI提示词..."
        await self._update_task(
            progress=17, status_message=msg,
            progress_details={"stage": "preparing", "message": msg}
        )

    async def generating(self, current_chars: int = 0, estimated_total: int = 5000,
                         message: Optional[str] = None, retry_count: int = 0,
                         max_retries: int = 3):
        """Report the generation stage (20-85%).

        Progress is derived from ``current_chars / estimated_total`` and never
        moves backwards until ``reset_generating_progress`` is called.
        """
        sub_progress = self._clamp_fraction(current_chars / max(estimated_total, 1))
        progress = 20 + int(65 * sub_progress)
        if progress < self._last_generating_progress:
            progress = self._last_generating_progress
        else:
            self._last_generating_progress = progress
        self.current_progress = progress

        retry_suffix = f" (重试 {retry_count}/{max_retries})" if retry_count > 0 else ""
        msg = message or f"生成{self.task_name}中... ({current_chars}字符){retry_suffix}"
        await self._update_task(
            progress=progress, status_message=msg,
            progress_details={"stage": "generating", "message": msg, "current_chars": current_chars}
        )

    async def parsing(self, message: Optional[str] = None):
        """Report the parsing stage (fixed at 88%)."""
        self.current_progress = 88
        msg = message or f"解析{self.task_name}数据..."
        await self._update_task(
            progress=88, status_message=msg,
            progress_details={"stage": "parsing", "message": msg}
        )

    async def saving(self, message: Optional[str] = None, sub_progress: float = 0.5):
        """Report the saving stage (92-98%); ``sub_progress`` is clamped to [0, 1]."""
        progress = 92 + int(6 * self._clamp_fraction(sub_progress))
        self.current_progress = progress
        msg = message or f"保存{self.task_name}到数据库..."
        await self._update_task(
            progress=progress, status_message=msg,
            progress_details={"stage": "saving", "message": msg}
        )

    async def complete(self, message: Optional[str] = None):
        """Mark the task as completed at 100% and stamp ``completed_at``."""
        self.current_progress = 100
        msg = message or f"{self.task_name}生成完成!"
        await self._update_task(
            status="completed", progress=100, status_message=msg,
            completed_at=datetime.now(),
            progress_details={"stage": "complete", "message": msg}
        )

    async def error(self, error_message: str):
        """Mark the task as failed; the stored progress value is left untouched."""
        await self._update_task(
            status="failed", error_message=error_message,
            status_message=f"失败: {error_message}",
            completed_at=datetime.now(),
            progress_details={"stage": "error", "message": error_message}
        )

    async def warning(self, message: str):
        """Store a warning message without changing status or progress."""
        await self._update_task(
            status_message=f"⚠️ {message}",
            progress_details={"stage": "warning", "message": message}
        )

    async def retry(self, retry_count: int, max_retries: int, reason: str = "准备重试"):
        """Record that a retry is about to happen, including the retry counter."""
        msg = f"⚠️ {reason}... ({retry_count}/{max_retries})"
        await self._update_task(
            status_message=msg, retry_count=retry_count,
            progress_details={"stage": "retry", "message": msg, "retry_count": retry_count}
        )

    def reset_generating_progress(self):
        """Allow ``generating`` to start again from 20%."""
        self._last_generating_progress = 20

    async def check_cancelled(self) -> bool:
        """Return True when the task row has ``cancel_requested`` set.

        A lookup failure is treated as "not cancelled", but it is now logged
        instead of being swallowed silently.
        """
        try:
            session_factory = await self._session_factory()
            async with session_factory() as session:
                result = await session.execute(
                    select(BackgroundTask.cancel_requested)
                    .where(BackgroundTask.id == self.task_id)
                )
                cancelled = result.scalar_one_or_none()
                return bool(cancelled)
        except Exception as e:
            logger.warning(f"⚠️ 检查任务取消状态失败: {e}")
            return False
|
||||
|
||||
|
||||
class BackgroundTaskService:
    """Background task manager with one FIFO queue per user.

    Tasks of the same user run one after another in submission order; each user
    has an own worker coroutine, so different users run concurrently.

    NOTE(review): the worker runs every queued item, including tasks whose row
    was set to "cancelled" while still pending; presumably the task function
    checks for cancellation itself - confirm against the callers.
    """

    def __init__(self):
        self._user_queues: Dict[str, asyncio.Queue] = {}  # user_id -> queue of work items
        self._user_workers: Dict[str, bool] = {}  # user_id -> worker running flag
        # Strong references to worker tasks: the event loop only keeps weak
        # references, so an unreferenced task may be garbage-collected mid-run.
        self._worker_tasks: Dict[str, asyncio.Task] = {}
        # Users whose worker is executing a task right now (item already dequeued).
        self._busy_users: set = set()

    def _ensure_user_queue(self, user_id: str) -> asyncio.Queue:
        """Return the queue of ``user_id``, creating it on first use."""
        if user_id not in self._user_queues:
            self._user_queues[user_id] = asyncio.Queue()
        return self._user_queues[user_id]

    async def _start_user_worker(self, user_id: str):
        """Start the worker coroutine of ``user_id`` unless it is already running."""
        if self._user_workers.get(user_id, False):
            return
        self._user_workers[user_id] = True
        self._worker_tasks[user_id] = asyncio.create_task(self._user_worker_loop(user_id))
        logger.info(f"📋 用户 {user_id[:8]} 的任务队列工作协程已启动")

    async def _mark_task_failed(self, user_id: str, task_id: str, error: Exception):
        """Best-effort: flag the task row as failed after its function raised.

        Rows still "pending" are included, because a task function can raise
        before it marks itself as running; such rows would otherwise never leave
        the pending state. Finished or cancelled rows are not touched.
        """
        try:
            engine = await get_engine(user_id)
            session_factory = async_sessionmaker(
                engine, class_=AsyncSession, expire_on_commit=False
            )
            async with session_factory() as session:
                result = await session.execute(
                    select(BackgroundTask).where(BackgroundTask.id == task_id)
                )
                task = result.scalar_one_or_none()
                if task and task.status in ("pending", "running"):
                    task.status = "failed"
                    task.error_message = str(error)
                    task.status_message = f"任务失败: {str(error)}"
                    task.completed_at = datetime.now()
                    await session.commit()
        except Exception as update_err:
            logger.error(f"❌ 更新失败任务状态失败: {update_err}")

    async def _mark_task_queued(self, user_id: str, task_id: str,
                                tasks_ahead: int, queue_size: int):
        """Best-effort: store the queue position on a task row that is still pending.

        ``tasks_ahead`` drives the user-facing message; ``queue_size`` (queue
        length including this task) is stored in ``progress_details``.
        """
        try:
            engine = await get_engine(user_id)
            session_factory = async_sessionmaker(
                engine, class_=AsyncSession, expire_on_commit=False
            )
            async with session_factory() as session:
                result = await session.execute(
                    select(BackgroundTask).where(BackgroundTask.id == task_id)
                )
                task = result.scalar_one_or_none()
                if task and task.status == "pending":
                    if tasks_ahead > 0:
                        task.status_message = f"排队中,前方还有 {tasks_ahead} 个任务等待..."
                    else:
                        task.status_message = "即将开始执行..."
                    task.progress_details = {"stage": "queued", "queue_size": queue_size}
                    task.updated_at = datetime.now()
                    await session.commit()
        except Exception as e:
            logger.error(f"更新队列位置信息失败: {e}")

    async def _user_worker_loop(self, user_id: str):
        """Take work items from the user's queue one by one and run them.

        An exception raised by a task function is logged and recorded on the
        task row; it never stops the loop.
        """
        queue = self._user_queues[user_id]
        try:
            while True:
                try:
                    task_item = await queue.get()
                    task_id = task_item["task_id"]
                    task_func = task_item["task_func"]
                    args = task_item["args"]
                    kwargs = task_item["kwargs"]

                    logger.info(f"🔄 [用户{user_id[:8]}] 队列开始执行任务: {task_id[:8]} (队列剩余: {queue.qsize()})")

                    # The item has left the queue; remember it is in flight so
                    # newly queued tasks can count it as "ahead".
                    self._busy_users.add(user_id)
                    try:
                        await task_func(task_id, args["user_id"], *args["extra_args"], **kwargs)
                    except Exception as e:
                        logger.error(f"❌ 后台任务 {task_id[:8]} 异常: {e}", exc_info=True)
                        await self._mark_task_failed(user_id, task_id, e)
                    finally:
                        self._busy_users.discard(user_id)
                        queue.task_done()
                        logger.info(f"✅ [用户{user_id[:8]}] 队列任务完成: {task_id[:8]} (队列剩余: {queue.qsize()})")

                except Exception as e:
                    logger.error(f"❌ [用户{user_id[:8]}] 队列工作循环异常: {e}", exc_info=True)
        finally:
            # Clear every marker so a later submission can start a fresh worker.
            self._user_workers.pop(user_id, None)
            self._worker_tasks.pop(user_id, None)
            self._busy_users.discard(user_id)
            logger.info(f"📋 用户 {user_id[:8]} 的工作协程已退出")

    @staticmethod
    async def create_task(
        user_id: str,
        project_id: str,
        task_type: str,
        task_input: Dict[str, Any] = None,
        db: AsyncSession = None
    ) -> BackgroundTask:
        """Insert a new task row in state "pending" and return it.

        Raises:
            ValueError: if no ``db`` session is supplied.
        """
        if db is None:
            # Fail with a clear message instead of AttributeError on db.add().
            raise ValueError("db session is required to create a background task")
        task = BackgroundTask(
            user_id=user_id,
            project_id=project_id,
            task_type=task_type,
            task_input=task_input or {},
            status="pending",
            progress=0,
            status_message="任务已创建,等待执行..."
        )
        db.add(task)
        await db.commit()
        await db.refresh(task)
        logger.info(f"📋 创建后台任务: {task.id[:8]} type={task_type} project={project_id[:8]}")
        return task

    @staticmethod
    async def get_task(task_id: str, user_id: str, db: AsyncSession) -> Optional[BackgroundTask]:
        """Return the task with this id if it belongs to ``user_id``, else None."""
        result = await db.execute(
            select(BackgroundTask).where(
                BackgroundTask.id == task_id,
                BackgroundTask.user_id == user_id
            )
        )
        return result.scalar_one_or_none()

    @staticmethod
    async def get_project_tasks(
        project_id: str, user_id: str, db: AsyncSession,
        task_type: str = None, limit: int = 20
    ) -> list:
        """Return the user's tasks of a project, newest first, at most ``limit`` rows.

        When ``task_type`` is given, only tasks of that type are returned.
        """
        query = (
            select(BackgroundTask)
            .where(
                BackgroundTask.project_id == project_id,
                BackgroundTask.user_id == user_id
            )
            .order_by(BackgroundTask.created_at.desc())
        )
        if task_type:
            query = query.where(BackgroundTask.task_type == task_type)
        query = query.limit(limit)
        result = await db.execute(query)
        return result.scalars().all()

    @staticmethod
    async def cancel_task(task_id: str, user_id: str, db: AsyncSession) -> bool:
        """Request cancellation of a pending or running task.

        Returns False when the task does not exist for this user or is no
        longer pending/running; True once the row has been marked cancelled.
        """
        result = await db.execute(
            select(BackgroundTask).where(
                BackgroundTask.id == task_id,
                BackgroundTask.user_id == user_id
            )
        )
        task = result.scalar_one_or_none()
        if not task:
            return False
        if task.status not in ("pending", "running"):
            return False
        task.cancel_requested = True
        task.status = "cancelled"
        task.status_message = "任务已取消"
        task.completed_at = datetime.now()
        await db.commit()
        logger.info(f"🚫 取消任务: {task_id[:8]}")
        return True

    @staticmethod
    async def cleanup_old_tasks(user_id: str, db: AsyncSession, days: int = 7):
        """Delete the user's finished tasks that completed more than ``days`` days ago."""
        from sqlalchemy import delete as sql_delete
        from datetime import timedelta
        cutoff = datetime.now() - timedelta(days=days)
        result = await db.execute(
            sql_delete(BackgroundTask).where(
                BackgroundTask.user_id == user_id,
                BackgroundTask.status.in_(["completed", "failed", "cancelled"]),
                BackgroundTask.completed_at < cutoff
            )
        )
        if result.rowcount > 0:
            await db.commit()
            logger.info(f"🧹 清理用户 {user_id[:8]} 的 {result.rowcount} 条旧任务记录")

    async def spawn_background_task(
        self,
        task_id: str,
        user_id: str,
        task_func: Callable[..., Awaitable],
        *args,
        **kwargs
    ):
        """Queue a task for ``user_id`` (FIFO per user, concurrent across users).

        Args:
            task_id: id of the task row.
            user_id: owner of the task; selects the queue.
            task_func: async callable, invoked as
                ``task_func(task_id, user_id, *args, **kwargs)``.
            *args, **kwargs: forwarded to ``task_func``.
        """
        # Make sure the user's queue and worker exist.
        queue = self._ensure_user_queue(user_id)
        await self._start_user_worker(user_id)

        # Count what is ahead BEFORE enqueueing: items still waiting plus the
        # one being executed. Reading qsize() after put() would count this task
        # itself and could never report an empty queue.
        tasks_ahead = queue.qsize() + (1 if user_id in self._busy_users else 0)

        await queue.put({
            "task_id": task_id,
            "task_func": task_func,
            "args": {"user_id": user_id, "extra_args": args},
            "kwargs": kwargs,
        })
        queue_size = queue.qsize()
        logger.info(f"📥 任务已加入用户 {user_id[:8]} 的队列: {task_id[:8]} (当前队列长度: {queue_size})")

        # Show the queue position on the task row.
        await self._mark_task_queued(user_id, task_id, tasks_ahead, queue_size)

    def get_queue_size(self, user_id: str = None) -> int:
        """Return the number of waiting items for one user, or for all users."""
        if user_id:
            queue = self._user_queues.get(user_id)
            return queue.qsize() if queue else 0
        # Total over all user queues.
        return sum(q.qsize() for q in self._user_queues.values())

    def get_all_queue_info(self) -> Dict[str, int]:
        """Return ``{user_id: waiting items}`` for every non-empty queue."""
        return {
            uid: q.qsize() for uid, q in self._user_queues.items() if q.qsize() > 0
        }
|
||||
|
||||
|
||||
# 全局单例
|
||||
# Module-level singleton instance, created when this module is imported.
background_task_service = BackgroundTaskService()
|
||||
@@ -1284,7 +1284,7 @@ class ForeshadowService:
|
||||
planted_foreshadows = await self.get_planted_foreshadows_for_analysis(db, project_id)
|
||||
|
||||
# 每章最多创建的新伏笔数量
|
||||
MAX_NEW_FORESHADOWS_PER_CHAPTER = 2
|
||||
MAX_NEW_FORESHADOWS_PER_CHAPTER = 5
|
||||
new_foreshadow_count = 0
|
||||
|
||||
for fs_data in analysis_foreshadows:
|
||||
|
||||
@@ -26,15 +26,101 @@ _QUOTE_MAP = {
|
||||
}
|
||||
|
||||
|
||||
def _is_content_quote(text: str, pos: int) -> bool:
    """Decide whether the '"' at ``text[pos]`` (inside a string value) is content.

    In valid JSON the first non-blank character after a string's closing quote
    is ',' (separator), '}' / ']' (closer) or ':' (after a key). A quote that
    is followed by anything else was written by the AI as part of the text and
    must be escaped.

    Args:
        text: the JSON text being repaired.
        pos: index of the '"' character to classify.

    Returns:
        True when the quote is content (needs escaping), False when it is
        treated as the closing quote of the string.
    """
    size = len(text)

    def skip(index: int, chars: str) -> int:
        """Return the first index >= ``index`` whose character is not in ``chars``."""
        while index < size and text[index] in chars:
            index += 1
        return index

    j = skip(pos + 1, ' \t')
    if j >= size:
        return False  # end of text: treat as closing quote

    ch = text[j]

    # '}' / ']' close the container; ':' cannot follow a value, so the quote is
    # conservatively treated as a closing (key) quote.
    if ch in '}]:':
        return False

    # Line break: look at the first character of the next non-empty line.
    # All whitespace is skipped, including blank lines, exactly like the
    # "comma followed by line break" case below. Previously a blank line
    # before '}' or '"' made a closing quote look like a content quote.
    if ch in '\r\n':
        k = skip(j, ' \t\r\n')
        if k >= size:
            return False
        # Next line starts with '"' (a key) or a closer -> closing quote.
        return text[k] not in '"}]'

    # Comma: structural only if what follows looks like the start of a JSON value.
    if ch == ',':
        k = skip(j + 1, ' \t')
        if k >= size:
            return False

        if text[k] in '\r\n':
            k = skip(k, ' \t\r\n')
            if k >= size:
                return False
            return text[k] not in '"}]'

        after_comma = text[k]
        if after_comma in '"-{[' or after_comma.isdigit():
            return False  # string/key, number, object or array
        if text.startswith(('true', 'null', 'false'), k):
            return False
        # The comma is followed by ordinary text -> content comma, content quote.
        return True

    # Any other character (CJK text, letters, ...) -> content quote.
    return True
|
||||
|
||||
|
||||
def _fix_json_string_values(text: str) -> str:
|
||||
"""
|
||||
修复JSON字符串值中的常见问题:
|
||||
1. 裸换行符/制表符 → 转义
|
||||
2. 字符串值内的中文引号 → 转义为ASCII引号(避免破坏JSON结构)
|
||||
3. 结构位置的中文引号 → 直接替换为ASCII引号
|
||||
上下文感知的 JSON 修复,区分字符串内外分别处理。
|
||||
|
||||
AI生成的JSON常在字符串值中插入未转义的换行符和中文引号。
|
||||
此函数遍历文本,区分字符串内外,分别处理。
|
||||
字符串值内:
|
||||
1. 裸换行符/制表符 → 转义
|
||||
2. 中文引号(""等) → 转义为 \\"
|
||||
3. 未转义的 ASCII 双引号 → 智能检测:内容引号转义,结束引号保留
|
||||
4. 中文逗号/冒号 → 保留原样(是内容字符)
|
||||
|
||||
结构位置(字符串外):
|
||||
1. 中文引号 → ASCII 引号
|
||||
2. 中文逗号 → ASCII 逗号
|
||||
3. 中文冒号 → ASCII 冒号
|
||||
"""
|
||||
if not text or '"' not in text:
|
||||
return text
|
||||
@@ -47,111 +133,234 @@ def _fix_json_string_values(text: str) -> str:
|
||||
while i < len(text):
|
||||
c = text[i]
|
||||
|
||||
if c == '"' and not in_string:
|
||||
# 进入字符串
|
||||
in_string = True
|
||||
# === 非字符串内(结构位置)===
|
||||
if not in_string:
|
||||
# 结构位置的中文标点 → ASCII
|
||||
if c == '\uff0c': # ,→ ,
|
||||
result.append(',')
|
||||
fixed_count += 1
|
||||
i += 1
|
||||
continue
|
||||
if c == '\uff1a': # :→ :
|
||||
result.append(':')
|
||||
fixed_count += 1
|
||||
i += 1
|
||||
continue
|
||||
if c in _QUOTE_MAP:
|
||||
result.append(_QUOTE_MAP[c])
|
||||
fixed_count += 1
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# ASCII 双引号 → 进入字符串
|
||||
if c == '"':
|
||||
in_string = True
|
||||
result.append(c)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
result.append(c)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if in_string:
|
||||
if c == '\\':
|
||||
# 转义字符,检查下一个字符是否合法
|
||||
if i + 1 < len(text):
|
||||
next_c = text[i + 1]
|
||||
# JSON 合法转义:\" \\ \/ \b \f \n \r \t \uXXXX
|
||||
if next_c in ('"', '\\', '/', 'b', 'f', 'n', 'r', 't'):
|
||||
# 合法转义,直接保留
|
||||
result.append(c)
|
||||
result.append(next_c)
|
||||
i += 2
|
||||
# === 字符串值内 ===
|
||||
|
||||
# 转义字符处理
|
||||
if c == '\\':
|
||||
if i + 1 < len(text):
|
||||
next_c = text[i + 1]
|
||||
if next_c in ('"', '\\', '/', 'b', 'f', 'n', 'r', 't'):
|
||||
result.append(c)
|
||||
result.append(next_c)
|
||||
i += 2
|
||||
continue
|
||||
elif next_c == 'u':
|
||||
if i + 5 < len(text) and all(text[i+2+k] in '0123456789abcdefABCDEF' for k in range(4)):
|
||||
result.append(text[i:i+6])
|
||||
i += 6
|
||||
continue
|
||||
elif next_c == 'u':
|
||||
# Unicode 转义 \uXXXX,检查是否有4个十六进制字符
|
||||
if i + 5 < len(text) and all(text[i+2+k] in '0123456789abcdefABCDEF' for k in range(4)):
|
||||
result.append(text[i:i+6])
|
||||
i += 6
|
||||
continue
|
||||
else:
|
||||
# 不完整的unicode转义,去掉反斜杠
|
||||
result.append(next_c)
|
||||
fixed_count += 1
|
||||
i += 2
|
||||
continue
|
||||
else:
|
||||
# 非法转义字符(如 \c \p \d 等),去掉反斜杠只保留字符
|
||||
result.append(next_c)
|
||||
fixed_count += 1
|
||||
i += 2
|
||||
continue
|
||||
else:
|
||||
# 末尾孤立的反斜杠,去掉
|
||||
result.append(next_c)
|
||||
fixed_count += 1
|
||||
i += 1
|
||||
i += 2
|
||||
continue
|
||||
|
||||
if c == '"':
|
||||
# 字符串结束
|
||||
else:
|
||||
fixed_count += 1
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# ASCII 双引号 → 智能判断是结束引号还是内容引号
|
||||
if c == '"':
|
||||
if _is_content_quote(text, i):
|
||||
# 内容引号,需要转义
|
||||
result.append('\\')
|
||||
result.append('"')
|
||||
fixed_count += 1
|
||||
i += 1
|
||||
continue
|
||||
else:
|
||||
# 结束引号
|
||||
in_string = False
|
||||
result.append(c)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if c == '\n':
|
||||
# 裸换行符 → 替换为转义换行
|
||||
result.append('\\')
|
||||
result.append('n')
|
||||
fixed_count += 1
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if c == '\r':
|
||||
# 裸回车符 → 忽略或替换
|
||||
if i + 1 < len(text) and text[i + 1] == '\n':
|
||||
result.append('\\')
|
||||
result.append('n')
|
||||
fixed_count += 1
|
||||
i += 2
|
||||
else:
|
||||
result.append('\\')
|
||||
result.append('n')
|
||||
fixed_count += 1
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if c == '\t':
|
||||
# 裸制表符 → 替换为转义制表符
|
||||
result.append('\\')
|
||||
result.append('t')
|
||||
fixed_count += 1
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# 字符串值内的中文引号 → 转义为 \"(避免破坏JSON结构)
|
||||
if c in _QUOTE_MAP:
|
||||
result.append('\\')
|
||||
result.append(_QUOTE_MAP[c])
|
||||
fixed_count += 1
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# 非字符串内的字符
|
||||
# 结构位置的中文引号 → 直接替换
|
||||
if not in_string and c in _QUOTE_MAP:
|
||||
result.append(_QUOTE_MAP[c])
|
||||
# 裸换行符 → 转义
|
||||
if c == '\n':
|
||||
result.append('\\')
|
||||
result.append('n')
|
||||
fixed_count += 1
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if c == '\r':
|
||||
if i + 1 < len(text) and text[i + 1] == '\n':
|
||||
result.append('\\')
|
||||
result.append('n')
|
||||
fixed_count += 1
|
||||
i += 2
|
||||
else:
|
||||
result.append('\\')
|
||||
result.append('n')
|
||||
fixed_count += 1
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if c == '\t':
|
||||
result.append('\\')
|
||||
result.append('t')
|
||||
fixed_count += 1
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# 中文引号处理
|
||||
if c in _QUOTE_MAP:
|
||||
mapped = _QUOTE_MAP[c]
|
||||
if mapped == '"':
|
||||
# 中文双引号在字符串内需要转义
|
||||
result.append('\\')
|
||||
result.append('"')
|
||||
else:
|
||||
# 中文单引号在双引号字符串内不需要转义,直接替换
|
||||
result.append(mapped)
|
||||
fixed_count += 1
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# 其他字符(包括中文逗号、中文冒号)→ 保留原样
|
||||
result.append(c)
|
||||
i += 1
|
||||
|
||||
if fixed_count > 0:
|
||||
logger.debug(f"✅ 修复了{fixed_count}个JSON问题(裸控制字符/中文引号)")
|
||||
logger.debug(f"✅ 修复了{fixed_count}个JSON问题(引号/控制字符/中文标点)")
|
||||
|
||||
return ''.join(result)
|
||||
|
||||
|
||||
def _fix_all_invalid_escapes(text: str) -> str:
    """Last-resort repair of invalid JSON escape sequences over the whole text.

    Unlike ``_fix_json_string_values`` this scan does not track string
    boundaries, so it still works when that tracking went wrong.

    Valid JSON escapes are ``\\"  \\\\  \\/  \\b  \\f  \\n  \\r  \\t`` and
    ``\\uXXXX`` (four hex digits); they are kept as they are. For every other
    ``\\X`` the backslash is dropped and only ``X`` is kept. A lone backslash
    at the very end of the text is left untouched.
    """
    if '\\' not in text:
        return text

    simple_escapes = '"\\/bfnrt'
    hex_digits = '0123456789abcdefABCDEF'
    length = len(text)
    pieces = []
    repaired = 0
    pos = 0

    while pos < length:
        current = text[pos]

        # Ordinary character, or a backslash with nothing after it.
        if current != '\\' or pos + 1 >= length:
            pieces.append(current)
            pos += 1
            continue

        follower = text[pos + 1]

        # Two-character escape that JSON accepts: keep both characters.
        if follower in simple_escapes:
            pieces.append(text[pos:pos + 2])
            pos += 2
            continue

        # \uXXXX is valid only with four hex digits behind it.
        if follower == 'u':
            digits = text[pos + 2:pos + 6]
            if len(digits) == 4 and all(d in hex_digits for d in digits):
                pieces.append(text[pos:pos + 6])
                pos += 6
                continue

        # Invalid escape (incomplete \u included): drop the backslash.
        pieces.append(follower)
        repaired += 1
        pos += 2

    if repaired > 0:
        logger.info(f"✅ 兜底修复了{repaired}个无效JSON转义序列")

    return ''.join(pieces)
|
||||
|
||||
|
||||
def _fix_multiple_objects_as_value(text: str) -> str:
    """Merge several objects that the AI emitted as the value of one property.

    Example::

        "key": {"a": "1"}, {"b": "2"}   ->   "key": {"a": "1", "b": "2"}

    The AI sometimes writes several independent objects where a single object
    value is expected (e.g. a ``relationship_changes`` field listing several
    characters). The pattern is applied repeatedly, at most 10 passes, so that
    three or more consecutive objects are folded into one.

    Args:
        text: JSON text to repair.

    Returns:
        The text with such object sequences merged; unchanged when none is found.
    """
    if '{' not in text or '}' not in text:
        return text

    # An object whose content may itself contain up to two further levels of braces.
    nested_obj = r'\{(?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*\}'

    # A property colon, an object, a comma, then another object without a
    # property name:  "key": {obj1}, {obj2}
    pattern = re.compile(r'(":)\s*(' + nested_obj + r')\s*,\s*(' + nested_obj + r')')

    def merge_objects(match):
        colon = match.group(1)
        # Strip the outer braces. An empty object contributes no members;
        # joining it unconditionally would produce "{, ...}", which is invalid.
        members = [
            content
            for content in (match.group(2)[1:-1], match.group(3)[1:-1])
            if content.strip()
        ]
        merged = ', '.join(members)
        return f'{colon} {{{merged}}}'

    max_iterations = 10
    merged_total = 0
    for _ in range(max_iterations):
        text, merged_now = pattern.subn(merge_objects, text)
        if merged_now == 0:
            break
        merged_total += merged_now

    if merged_total > 0:
        # Report the real number of merges, not the number of passes.
        logger.info(f"✅ 修复了{merged_total}处多对象属性值合并")

    return text
|
||||
|
||||
|
||||
def clean_json_response(text: str) -> str:
|
||||
"""清洗 AI 返回的 JSON(改进版 - 流式安全)"""
|
||||
try:
|
||||
@@ -162,11 +371,8 @@ def clean_json_response(text: str) -> str:
|
||||
original_length = len(text)
|
||||
logger.debug(f"🔍 开始清洗JSON,原始长度: {original_length}")
|
||||
|
||||
# 替换中文逗号/冒号(AI可能在JSON结构位置使用,全局替换是安全的)
|
||||
text = text.replace('\uff0c', ',') # ,→ ,
|
||||
text = text.replace('\uff1a', ':') # :→ :
|
||||
|
||||
# 修复JSON中的中文引号和裸控制字符(上下文感知,区分字符串内外)
|
||||
# 上下文感知修复:中文引号/逗号/冒号、裸控制字符、未转义的内容引号
|
||||
# (区分字符串内外:结构位置替换为ASCII,字符串内保留或转义)
|
||||
text = _fix_json_string_values(text)
|
||||
|
||||
# 去除 markdown 代码块
|
||||
@@ -286,9 +492,35 @@ def clean_json_response(text: str) -> str:
|
||||
json.loads(result)
|
||||
logger.debug(f"✅ 清洗后JSON验证成功")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"❌ 清洗后JSON仍然无效: {e}")
|
||||
logger.debug(f" 结果预览: {result[:500]}")
|
||||
logger.debug(f" 结果结尾: ...{result[-200:]}")
|
||||
logger.warning(f"⚠️ 清洗后JSON仍然无效: {e},尝试修复结构性问题...")
|
||||
|
||||
# 修复1:合并多对象属性值(AI可能输出 "key": {a:1}, {b:2} )
|
||||
result = _fix_multiple_objects_as_value(result)
|
||||
|
||||
try:
|
||||
json.loads(result)
|
||||
logger.info(f"✅ 修复多对象属性值后JSON验证成功")
|
||||
except json.JSONDecodeError:
|
||||
pass # 继续尝试其他修复
|
||||
else:
|
||||
return result
|
||||
|
||||
# 修复2:兜底修复无效转义序列(不依赖字符串边界追踪)
|
||||
logger.warning(f"⚠️ 继续尝试兜底修复无效转义...")
|
||||
result = _fix_all_invalid_escapes(result)
|
||||
try:
|
||||
json.loads(result)
|
||||
logger.info(f"✅ 兜底修复后JSON验证成功")
|
||||
except json.JSONDecodeError as e2:
|
||||
# 修复3:再次尝试合并多对象属性值(转义修复后可能产生新的合并机会)
|
||||
result = _fix_multiple_objects_as_value(result)
|
||||
try:
|
||||
json.loads(result)
|
||||
logger.info(f"✅ 二次修复后JSON验证成功")
|
||||
except json.JSONDecodeError as e3:
|
||||
logger.error(f"❌ 所有修复后JSON仍然无效: {e3}")
|
||||
logger.debug(f" 结果预览: {result[:500]}")
|
||||
logger.debug(f" 结果结尾: ...{result[-200:]}")
|
||||
|
||||
return result
|
||||
|
||||
@@ -339,6 +571,16 @@ def loads_json(text: str) -> Any:
|
||||
except (json.JSONDecodeError, Exception):
|
||||
pass
|
||||
|
||||
# 兜底修复无效转义序列后重试
|
||||
fixed_text = _fix_all_invalid_escapes(text)
|
||||
if fixed_text != text:
|
||||
try:
|
||||
result = json.loads(fixed_text)
|
||||
logger.info("✅ 兜底修复无效转义后json.loads成功")
|
||||
return result
|
||||
except (json.JSONDecodeError, Exception):
|
||||
pass
|
||||
|
||||
# json5 容错解析
|
||||
if HAS_JSON5:
|
||||
try:
|
||||
@@ -347,6 +589,14 @@ def loads_json(text: str) -> Any:
|
||||
logger.info("✅ json5容错解析成功")
|
||||
return result
|
||||
except Exception as e5:
|
||||
# json5也失败,尝试对修复后的文本使用json5
|
||||
if fixed_text != text:
|
||||
try:
|
||||
result = json5.loads(fixed_text)
|
||||
logger.info("✅ 兜底修复无效转义后json5容错解析成功")
|
||||
return result
|
||||
except Exception:
|
||||
pass
|
||||
logger.error(f"❌ json5容错解析也失败: {e5}")
|
||||
|
||||
# 最终失败,抛出标准异常
|
||||
|
||||
Reference in New Issue
Block a user