2025-10-30 11:14:43 +08:00
|
|
|
"""Server-Sent Events (SSE) 响应工具类"""
|
|
|
|
|
import json
|
|
|
|
|
import asyncio
|
2026-01-09 17:13:19 +08:00
|
|
|
from enum import Enum
|
|
|
|
|
from typing import AsyncGenerator, Dict, Any, Optional, Callable
|
|
|
|
|
from dataclasses import dataclass
|
2025-10-30 11:14:43 +08:00
|
|
|
from fastapi.responses import StreamingResponse
|
|
|
|
|
from app.logger import get_logger
|
|
|
|
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
|
|
|
2026-01-09 17:13:19 +08:00
|
|
|
class ProgressStage(Enum):
|
|
|
|
|
"""标准化进度阶段枚举"""
|
|
|
|
|
# 初始化阶段 (0-5%)
|
|
|
|
|
INIT = "init"
|
|
|
|
|
# 加载数据阶段 (5-15%)
|
|
|
|
|
LOADING = "loading"
|
|
|
|
|
# 准备提示词阶段 (15-20%)
|
|
|
|
|
PREPARING = "preparing"
|
|
|
|
|
# AI生成阶段 (20-85%)
|
|
|
|
|
GENERATING = "generating"
|
|
|
|
|
# 解析数据阶段 (85-92%)
|
|
|
|
|
PARSING = "parsing"
|
|
|
|
|
# 保存数据阶段 (92-98%)
|
|
|
|
|
SAVING = "saving"
|
|
|
|
|
# 完成阶段 (100%)
|
|
|
|
|
COMPLETE = "complete"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
class StageConfig:
|
|
|
|
|
"""阶段配置"""
|
|
|
|
|
start: int # 起始进度
|
|
|
|
|
end: int # 结束进度
|
|
|
|
|
default_message: str # 默认消息
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 标准进度阶段配置
|
|
|
|
|
STAGE_CONFIGS: Dict[ProgressStage, StageConfig] = {
|
|
|
|
|
ProgressStage.INIT: StageConfig(0, 5, "开始处理..."),
|
|
|
|
|
ProgressStage.LOADING: StageConfig(5, 15, "加载数据中..."),
|
|
|
|
|
ProgressStage.PREPARING: StageConfig(15, 20, "准备AI提示词..."),
|
|
|
|
|
ProgressStage.GENERATING: StageConfig(20, 85, "AI生成中..."),
|
|
|
|
|
ProgressStage.PARSING: StageConfig(85, 92, "解析数据..."),
|
|
|
|
|
ProgressStage.SAVING: StageConfig(92, 98, "保存到数据库..."),
|
|
|
|
|
ProgressStage.COMPLETE: StageConfig(100, 100, "完成!"),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class WizardProgressTracker:
|
|
|
|
|
"""
|
|
|
|
|
向导进度追踪器 - 标准化管理SSE进度推送
|
|
|
|
|
|
|
|
|
|
使用示例:
|
|
|
|
|
tracker = WizardProgressTracker("世界观")
|
|
|
|
|
yield await tracker.start()
|
|
|
|
|
yield await tracker.loading("加载项目信息")
|
|
|
|
|
yield await tracker.preparing()
|
|
|
|
|
async for chunk in ai_stream:
|
|
|
|
|
yield await tracker.generating_chunk(chunk, len(accumulated))
|
|
|
|
|
yield await tracker.parsing()
|
|
|
|
|
yield await tracker.saving("保存世界观数据")
|
|
|
|
|
yield await tracker.complete()
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, task_name: str = "任务"):
|
|
|
|
|
"""
|
|
|
|
|
初始化进度追踪器
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
task_name: 任务名称,用于消息前缀
|
|
|
|
|
"""
|
|
|
|
|
self.task_name = task_name
|
|
|
|
|
self.current_stage = ProgressStage.INIT
|
|
|
|
|
self.current_progress = 0
|
|
|
|
|
self._last_generating_progress = 20 # 生成阶段的最后进度值
|
|
|
|
|
|
|
|
|
|
def _get_stage_progress(
|
|
|
|
|
self,
|
|
|
|
|
stage: ProgressStage,
|
|
|
|
|
sub_progress: float = 0.0
|
|
|
|
|
) -> int:
|
|
|
|
|
"""
|
|
|
|
|
计算阶段内的进度值
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
stage: 当前阶段
|
|
|
|
|
sub_progress: 阶段内子进度 (0.0-1.0)
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
总进度值 (0-100)
|
|
|
|
|
"""
|
|
|
|
|
config = STAGE_CONFIGS[stage]
|
|
|
|
|
if sub_progress <= 0:
|
|
|
|
|
return config.start
|
|
|
|
|
if sub_progress >= 1:
|
|
|
|
|
return config.end
|
|
|
|
|
return config.start + int((config.end - config.start) * sub_progress)
|
|
|
|
|
|
|
|
|
|
async def start(self, message: str = None) -> str:
|
|
|
|
|
"""开始阶段"""
|
|
|
|
|
self.current_stage = ProgressStage.INIT
|
|
|
|
|
self.current_progress = 0
|
|
|
|
|
msg = message or f"开始生成{self.task_name}..."
|
|
|
|
|
return await SSEResponse.send_progress(msg, 0, "processing")
|
|
|
|
|
|
|
|
|
|
async def loading(self, message: str = None, sub_progress: float = 0.5) -> str:
|
|
|
|
|
"""加载数据阶段"""
|
|
|
|
|
self.current_stage = ProgressStage.LOADING
|
|
|
|
|
progress = self._get_stage_progress(ProgressStage.LOADING, sub_progress)
|
|
|
|
|
self.current_progress = progress
|
|
|
|
|
msg = message or STAGE_CONFIGS[ProgressStage.LOADING].default_message
|
|
|
|
|
return await SSEResponse.send_progress(msg, progress, "processing")
|
|
|
|
|
|
|
|
|
|
async def preparing(self, message: str = None) -> str:
|
|
|
|
|
"""准备提示词阶段"""
|
|
|
|
|
self.current_stage = ProgressStage.PREPARING
|
|
|
|
|
progress = self._get_stage_progress(ProgressStage.PREPARING, 0.5)
|
|
|
|
|
self.current_progress = progress
|
|
|
|
|
msg = message or STAGE_CONFIGS[ProgressStage.PREPARING].default_message
|
|
|
|
|
return await SSEResponse.send_progress(msg, progress, "processing")
|
|
|
|
|
|
|
|
|
|
async def generating(
|
|
|
|
|
self,
|
|
|
|
|
current_chars: int = 0,
|
|
|
|
|
estimated_total: int = 5000,
|
|
|
|
|
message: str = None,
|
|
|
|
|
retry_count: int = 0,
|
|
|
|
|
max_retries: int = 3
|
|
|
|
|
) -> str:
|
|
|
|
|
"""
|
|
|
|
|
AI生成阶段进度更新
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
current_chars: 当前已生成字符数
|
|
|
|
|
estimated_total: 预估总字符数
|
|
|
|
|
message: 自定义消息
|
|
|
|
|
retry_count: 当前重试次数
|
|
|
|
|
max_retries: 最大重试次数
|
|
|
|
|
"""
|
|
|
|
|
self.current_stage = ProgressStage.GENERATING
|
|
|
|
|
|
|
|
|
|
# 计算生成进度 (0.0-1.0)
|
|
|
|
|
sub_progress = min(current_chars / max(estimated_total, 1), 1.0)
|
|
|
|
|
progress = self._get_stage_progress(ProgressStage.GENERATING, sub_progress)
|
|
|
|
|
|
|
|
|
|
# 确保进度单调递增
|
|
|
|
|
if progress < self._last_generating_progress:
|
|
|
|
|
progress = self._last_generating_progress
|
|
|
|
|
else:
|
|
|
|
|
self._last_generating_progress = progress
|
|
|
|
|
|
|
|
|
|
self.current_progress = progress
|
|
|
|
|
|
|
|
|
|
# 构建消息
|
|
|
|
|
retry_suffix = f" (重试 {retry_count}/{max_retries})" if retry_count > 0 else ""
|
|
|
|
|
if message:
|
|
|
|
|
msg = f"{message}{retry_suffix}"
|
|
|
|
|
else:
|
|
|
|
|
msg = f"生成{self.task_name}中... ({current_chars}字符){retry_suffix}"
|
|
|
|
|
|
|
|
|
|
return await SSEResponse.send_progress(msg, progress, "processing")
|
|
|
|
|
|
|
|
|
|
async def generating_chunk(self, chunk: str) -> str:
|
|
|
|
|
"""发送生成的内容块"""
|
|
|
|
|
return await SSEResponse.send_chunk(chunk)
|
|
|
|
|
|
|
|
|
|
async def parsing(self, message: str = None, sub_progress: float = 0.5) -> str:
|
|
|
|
|
"""解析数据阶段"""
|
|
|
|
|
self.current_stage = ProgressStage.PARSING
|
|
|
|
|
progress = self._get_stage_progress(ProgressStage.PARSING, sub_progress)
|
|
|
|
|
self.current_progress = progress
|
|
|
|
|
msg = message or f"解析{self.task_name}数据..."
|
|
|
|
|
return await SSEResponse.send_progress(msg, progress, "processing")
|
|
|
|
|
|
|
|
|
|
async def saving(self, message: str = None, sub_progress: float = 0.5) -> str:
|
|
|
|
|
"""保存数据阶段"""
|
|
|
|
|
self.current_stage = ProgressStage.SAVING
|
|
|
|
|
progress = self._get_stage_progress(ProgressStage.SAVING, sub_progress)
|
|
|
|
|
self.current_progress = progress
|
|
|
|
|
msg = message or f"保存{self.task_name}到数据库..."
|
|
|
|
|
return await SSEResponse.send_progress(msg, progress, "processing")
|
|
|
|
|
|
|
|
|
|
async def complete(self, message: str = None) -> str:
|
|
|
|
|
"""完成阶段"""
|
|
|
|
|
self.current_stage = ProgressStage.COMPLETE
|
|
|
|
|
self.current_progress = 100
|
|
|
|
|
msg = message or f"{self.task_name}生成完成!"
|
|
|
|
|
return await SSEResponse.send_progress(msg, 100, "success")
|
|
|
|
|
|
|
|
|
|
async def warning(self, message: str) -> str:
|
|
|
|
|
"""发送警告消息(保持当前进度)"""
|
|
|
|
|
return await SSEResponse.send_progress(
|
|
|
|
|
f"⚠️ {message}",
|
|
|
|
|
self.current_progress,
|
|
|
|
|
"warning"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
async def retry(self, retry_count: int, max_retries: int, reason: str = "准备重试") -> str:
|
|
|
|
|
"""发送重试消息"""
|
|
|
|
|
return await SSEResponse.send_progress(
|
|
|
|
|
f"⚠️ {reason}... ({retry_count}/{max_retries})",
|
|
|
|
|
self.current_progress,
|
|
|
|
|
"warning"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
async def error(self, error_message: str, code: int = 500) -> str:
|
|
|
|
|
"""发送错误消息"""
|
|
|
|
|
return await SSEResponse.send_error(error_message, code)
|
|
|
|
|
|
|
|
|
|
async def result(self, data: Dict[str, Any]) -> str:
|
|
|
|
|
"""发送结果数据"""
|
|
|
|
|
return await SSEResponse.send_result(data)
|
|
|
|
|
|
|
|
|
|
async def done(self) -> str:
|
|
|
|
|
"""发送完成信号"""
|
|
|
|
|
return await SSEResponse.send_done()
|
|
|
|
|
|
|
|
|
|
async def heartbeat(self) -> str:
|
|
|
|
|
"""发送心跳"""
|
|
|
|
|
return await SSEResponse.send_heartbeat()
|
|
|
|
|
|
|
|
|
|
def reset_generating_progress(self):
|
|
|
|
|
"""重置生成阶段进度(用于重试时)"""
|
|
|
|
|
self._last_generating_progress = 20
|
|
|
|
|
|
|
|
|
|
|
2025-10-30 11:14:43 +08:00
|
|
|
class SSEResponse:
|
|
|
|
|
"""SSE响应构建器"""
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def format_sse(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
|
|
|
|
"""
|
|
|
|
|
格式化SSE消息
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
data: 要发送的数据字典
|
|
|
|
|
event: 事件类型(可选)
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
格式化后的SSE消息字符串
|
|
|
|
|
"""
|
2025-12-28 19:35:23 +08:00
|
|
|
try:
|
|
|
|
|
message = ""
|
|
|
|
|
if event:
|
|
|
|
|
message += f"event: {event}\n"
|
|
|
|
|
message += f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
|
|
|
|
|
return message
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"❌ SSE格式化失败: {type(e).__name__}: {e}")
|
|
|
|
|
logger.error(f" data类型: {type(data)}")
|
|
|
|
|
logger.error(f" data内容: {str(data)[:500]}")
|
|
|
|
|
# 返回错误消息而不是崩溃
|
|
|
|
|
error_message = ""
|
|
|
|
|
if event:
|
|
|
|
|
error_message += f"event: {event}\n"
|
|
|
|
|
error_message += f'data: {{"type": "error", "error": "SSE格式化失败: {str(e)}", "code": 500}}\n\n'
|
|
|
|
|
return error_message
|
2025-10-30 11:14:43 +08:00
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
async def send_progress(
|
|
|
|
|
message: str,
|
|
|
|
|
progress: int,
|
|
|
|
|
status: str = "processing"
|
|
|
|
|
) -> str:
|
|
|
|
|
"""
|
|
|
|
|
发送进度消息
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
message: 进度消息
|
|
|
|
|
progress: 进度百分比(0-100)
|
|
|
|
|
status: 状态(processing/success/error)
|
|
|
|
|
"""
|
|
|
|
|
return SSEResponse.format_sse({
|
|
|
|
|
"type": "progress",
|
|
|
|
|
"message": message,
|
|
|
|
|
"progress": progress,
|
|
|
|
|
"status": status
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
async def send_chunk(content: str) -> str:
|
|
|
|
|
"""
|
|
|
|
|
发送内容块(用于流式输出AI生成内容)
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
content: 内容块
|
|
|
|
|
"""
|
|
|
|
|
return SSEResponse.format_sse({
|
|
|
|
|
"type": "chunk",
|
|
|
|
|
"content": content
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
async def send_result(data: Dict[str, Any]) -> str:
|
|
|
|
|
"""
|
|
|
|
|
发送最终结果
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
data: 结果数据
|
|
|
|
|
"""
|
|
|
|
|
return SSEResponse.format_sse({
|
|
|
|
|
"type": "result",
|
|
|
|
|
"data": data
|
|
|
|
|
})
|
|
|
|
|
|
2025-12-11 12:43:28 +08:00
|
|
|
@staticmethod
|
|
|
|
|
async def send_event(event: str, data: Dict[str, Any]) -> str:
|
|
|
|
|
"""
|
|
|
|
|
发送自定义事件类型的SSE消息
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
event: 事件类型名称
|
|
|
|
|
data: 事件数据
|
|
|
|
|
"""
|
|
|
|
|
return SSEResponse.format_sse(data, event=event)
|
|
|
|
|
|
2025-10-30 11:14:43 +08:00
|
|
|
@staticmethod
|
|
|
|
|
async def send_error(error: str, code: int = 500) -> str:
|
|
|
|
|
"""
|
|
|
|
|
发送错误消息
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
error: 错误描述
|
|
|
|
|
code: 错误码
|
|
|
|
|
"""
|
|
|
|
|
return SSEResponse.format_sse({
|
|
|
|
|
"type": "error",
|
|
|
|
|
"error": error,
|
|
|
|
|
"code": code
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
async def send_done() -> str:
|
|
|
|
|
"""发送完成消息"""
|
|
|
|
|
return SSEResponse.format_sse({
|
|
|
|
|
"type": "done"
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
async def send_heartbeat() -> str:
|
|
|
|
|
"""发送心跳消息(保持连接活跃)"""
|
|
|
|
|
return ": heartbeat\n\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def create_sse_generator(
|
|
|
|
|
async_gen: AsyncGenerator[str, None],
|
|
|
|
|
show_progress: bool = True
|
|
|
|
|
) -> AsyncGenerator[str, None]:
|
|
|
|
|
"""
|
|
|
|
|
创建SSE生成器包装器
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
async_gen: 异步生成器
|
|
|
|
|
show_progress: 是否显示进度
|
|
|
|
|
|
|
|
|
|
Yields:
|
|
|
|
|
格式化的SSE消息
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
if show_progress:
|
|
|
|
|
yield await SSEResponse.send_progress("开始生成...", 0)
|
|
|
|
|
|
|
|
|
|
# 累积内容用于进度计算
|
|
|
|
|
accumulated_content = ""
|
|
|
|
|
chunk_count = 0
|
|
|
|
|
|
|
|
|
|
async for chunk in async_gen:
|
|
|
|
|
chunk_count += 1
|
|
|
|
|
accumulated_content += chunk
|
|
|
|
|
|
|
|
|
|
# 发送内容块
|
|
|
|
|
yield await SSEResponse.send_chunk(chunk)
|
|
|
|
|
|
|
|
|
|
# 每10个块发送一次心跳
|
|
|
|
|
if chunk_count % 10 == 0:
|
|
|
|
|
yield await SSEResponse.send_heartbeat()
|
|
|
|
|
|
|
|
|
|
if show_progress:
|
|
|
|
|
yield await SSEResponse.send_progress("生成完成", 100, "success")
|
|
|
|
|
|
|
|
|
|
# 发送完成信号
|
|
|
|
|
yield await SSEResponse.send_done()
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"SSE生成器错误: {str(e)}")
|
|
|
|
|
yield await SSEResponse.send_error(str(e))
|
|
|
|
|
|
|
|
|
|
|
2026-04-26 13:58:15 +08:00
|
|
|
class _HeartbeatSentinel:
|
|
|
|
|
"""心跳哨兵对象,用于标识心跳事件(非AI内容)"""
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
HEARTBEAT = _HeartbeatSentinel()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def wrap_stream_with_heartbeat(
|
|
|
|
|
async_gen: AsyncGenerator,
|
|
|
|
|
heartbeat_interval: float = 15.0
|
|
|
|
|
) -> AsyncGenerator:
|
|
|
|
|
"""
|
|
|
|
|
包装异步生成器,在等待数据时产生心跳哨兵,防止连接超时断开。
|
|
|
|
|
|
|
|
|
|
用法:
|
|
|
|
|
async for chunk in wrap_stream_with_heartbeat(
|
|
|
|
|
ai_service.generate_text_stream(prompt),
|
|
|
|
|
heartbeat_interval=15
|
|
|
|
|
):
|
|
|
|
|
if chunk is HEARTBEAT:
|
|
|
|
|
yield await tracker.heartbeat()
|
|
|
|
|
continue
|
|
|
|
|
# chunk 是原始AI数据
|
|
|
|
|
"""
|
|
|
|
|
ait = async_gen.__aiter__()
|
|
|
|
|
while True:
|
|
|
|
|
try:
|
|
|
|
|
item = await asyncio.wait_for(ait.__anext__(), timeout=heartbeat_interval)
|
|
|
|
|
yield item
|
|
|
|
|
except asyncio.TimeoutError:
|
|
|
|
|
# 等待超时,产生心跳哨兵
|
|
|
|
|
yield HEARTBEAT
|
|
|
|
|
except StopAsyncIteration:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
2025-10-30 11:14:43 +08:00
|
|
|
def create_sse_response(generator: AsyncGenerator[str, None]) -> StreamingResponse:
|
|
|
|
|
"""
|
2026-01-06 09:12:11 +08:00
|
|
|
创建SSE StreamingResponse - 兼容HTTP/2协议
|
2025-10-30 11:14:43 +08:00
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
generator: SSE消息生成器
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
StreamingResponse对象
|
2026-01-06 09:12:11 +08:00
|
|
|
|
|
|
|
|
注意:
|
|
|
|
|
- HTTP/2不支持Connection头,已移除
|
|
|
|
|
- 明确指定charset=utf-8以确保编码正确
|
|
|
|
|
- 添加CORS头以支持跨域请求
|
2025-10-30 11:14:43 +08:00
|
|
|
"""
|
2025-11-11 19:50:12 +08:00
|
|
|
async def wrapper():
|
|
|
|
|
"""包装生成器以捕获StreamingResponse初始化时的GeneratorExit"""
|
|
|
|
|
try:
|
|
|
|
|
async for chunk in generator:
|
|
|
|
|
yield chunk
|
|
|
|
|
except GeneratorExit:
|
|
|
|
|
# StreamingResponse在初始化时会进行类型检查,导致GeneratorExit
|
|
|
|
|
# 这是正常行为,不需要记录警告
|
|
|
|
|
pass
|
|
|
|
|
|
2025-10-30 11:14:43 +08:00
|
|
|
return StreamingResponse(
|
2025-11-11 19:50:12 +08:00
|
|
|
wrapper(),
|
2026-01-06 09:12:11 +08:00
|
|
|
media_type="text/event-stream; charset=utf-8", # 明确指定charset
|
2025-10-30 11:14:43 +08:00
|
|
|
headers={
|
2026-01-06 09:12:11 +08:00
|
|
|
"Cache-Control": "no-cache, no-transform", # 禁用缓存和转换
|
|
|
|
|
# 移除 Connection: keep-alive (HTTP/2不兼容)
|
2025-10-30 11:14:43 +08:00
|
|
|
"X-Accel-Buffering": "no", # 禁用nginx缓冲
|
2026-01-06 09:12:11 +08:00
|
|
|
"Access-Control-Allow-Origin": "*", # CORS支持
|
|
|
|
|
"Access-Control-Allow-Methods": "POST, GET, OPTIONS",
|
|
|
|
|
"Access-Control-Allow-Headers": "Content-Type, Authorization",
|
2025-10-30 11:14:43 +08:00
|
|
|
}
|
|
|
|
|
)
|